From 8d43212bc617c80620a29273595c56cdcfdbe7bc Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 1 Jan 2026 18:40:58 -0500 Subject: [PATCH] assembly/amd: start work on SQTT parsing/emulation --- extra/assembly/amd/emu.py | 159 ++++++++ extra/assembly/amd/sqtt.py | 394 +++++++++++++++++++ extra/assembly/amd/test/discover_instops.py | 190 +++++++++ extra/assembly/amd/test/test_sqtt.py | 79 ++++ extra/assembly/amd/test/test_sqtt_compare.py | 156 ++++++++ extra/assembly/amd/test/test_sqtt_hw.py | 276 +++++++++++++ extra/assembly/amd/test/test_sqtt_ops.py | 204 ++++++++++ tinygrad/runtime/ops_amd.py | 7 +- 8 files changed, 1462 insertions(+), 3 deletions(-) create mode 100644 extra/assembly/amd/sqtt.py create mode 100644 extra/assembly/amd/test/discover_instops.py create mode 100644 extra/assembly/amd/test/test_sqtt.py create mode 100644 extra/assembly/amd/test/test_sqtt_compare.py create mode 100644 extra/assembly/amd/test/test_sqtt_hw.py create mode 100644 extra/assembly/amd/test/test_sqtt_ops.py diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index dd6395cf30..f92ed84a8c 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -2,6 +2,7 @@ # mypy: ignore-errors from __future__ import annotations import ctypes +from enum import IntEnum from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64 from extra.assembly.amd.pcode import Reg from extra.assembly.amd.asm import detect_format @@ -90,6 +91,164 @@ class LDSMem: SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16} +# ═══════════════════════════════════════════════════════════════════════════════ +# TIMING MODEL - Instruction latencies and register dependency tracking +# ═══════════════════════════════════════════════════════════════════════════════ + +class InstType(IntEnum): + """Instruction type for timing model.""" + SALU = 0 # Scalar ALU + VALU = 1 # Vector ALU (simple) + TRANS32 = 2 # Transcendental (v_rcp, v_rsq, v_sqrt, v_log, v_exp, v_sin, v_cos) + SMEM = 3 # Scalar memory + VMEM = 4 # Vector memory (global/scratch) + LDS = 5 # LDS operations + BRANCH = 6 # Branch/control flow + NOP = 7 # No operation + OTHER = 8 # Other + +# Latencies from s_delay_alu encoding (VALU_DEP_1-4, TRANS32_DEP_1-3, SALU_CYCLE_1-3) +INST_LATENCY = { + InstType.SALU: 1, # SALU_CYCLE_1 + InstType.VALU: 1, # VALU_DEP_1 (with forwarding, no stall) + InstType.TRANS32: 4, # TRANS32_DEP_3 (conservative) + InstType.SMEM: 4, # Variable, use small fixed value + InstType.VMEM: 1, # Variable, skip for now (issue immediately) + InstType.LDS: 1, # Variable, skip for now + InstType.BRANCH: 1, + InstType.NOP: 1, + InstType.OTHER: 1, +} + +# Transcendental ops - higher latency +_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64', + 'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'} + +def classify_inst(inst: Inst) -> InstType: + """Classify instruction for timing model.""" + if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)): + return InstType.SALU + if isinstance(inst, SOPP): + if inst.op in (SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER): return InstType.OTHER + if hasattr(inst, 'op') and 'BRANCH' in getattr(inst.op, 'name', ''): return InstType.BRANCH + return InstType.NOP + if isinstance(inst, SMEM): + return InstType.SMEM + if isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)): + op_name = inst.op_name if hasattr(inst, 'op_name') else '' + if any(t in op_name for t in _TRANS_OPS): + return InstType.TRANS32 + return InstType.VALU + if isinstance(inst, VOPD): + return InstType.VALU + if isinstance(inst, FLAT): + return InstType.VMEM + if isinstance(inst, DS): + return InstType.LDS + return InstType.OTHER + +class TimingState: + """Timing state for cycle-accurate emulation with SQTT output.""" + __slots__ = ('cycle', 'sgpr_ready', 'vgpr_ready', 'packets', 'wave_id', 'simd', 'cu') + + def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0): + self.cycle = 0 # Current cycle + self.sgpr_ready: dict[int, int] = {} # SGPR -> cycle when ready + self.vgpr_ready: dict[int, int] = {} # VGPR -> cycle when ready (all lanes same for now) + self.packets: list = [] # SQTT packets + self.wave_id, self.simd, self.cu = wave_id, simd, cu + + def stall_for_read(self, inst: Inst, st: 'WaveState') -> int: + """Calculate stall cycles needed before executing this instruction.""" + max_ready = self.cycle + + # Get source registers + srcs = [] + if isinstance(inst, (SOP1, SOP2, SOPC)): + if hasattr(inst, 'ssrc0') and inst.ssrc0 < SGPR_COUNT: srcs.append(('s', inst.ssrc0)) + if hasattr(inst, 'ssrc1') and inst.ssrc1 < SGPR_COUNT: srcs.append(('s', inst.ssrc1)) + elif isinstance(inst, SOPK): + if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT: srcs.append(('s', inst.sdst)) + elif isinstance(inst, SMEM): + if hasattr(inst, 'sbase'): srcs.append(('s', inst.sbase * 2)); srcs.append(('s', inst.sbase * 2 + 1)) + if hasattr(inst, 'soffset') and inst.soffset < SGPR_COUNT: srcs.append(('s', inst.soffset)) + elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOPC)): + if hasattr(inst, 'src0'): + if inst.src0 < SGPR_COUNT: srcs.append(('s', inst.src0)) + elif inst.src0 >= 256: srcs.append(('v', inst.src0 - 256)) + if hasattr(inst, 'src1'): + if inst.src1 < SGPR_COUNT: srcs.append(('s', inst.src1)) + elif inst.src1 >= 256: srcs.append(('v', inst.src1 - 256)) + if hasattr(inst, 'src2'): + if inst.src2 < SGPR_COUNT: srcs.append(('s', inst.src2)) + elif inst.src2 >= 256: srcs.append(('v', inst.src2 - 256)) + if hasattr(inst, 'vsrc1'): srcs.append(('v', inst.vsrc1)) + + # Find max ready time across sources + for typ, reg in srcs: + if typ == 's': + ready = self.sgpr_ready.get(reg, 0) + else: + ready = self.vgpr_ready.get(reg, 0) + if ready > max_ready: + max_ready = ready + + return max(0, max_ready - self.cycle) + + def schedule_write(self, inst: Inst, latency: int): + """Record when destination registers will be ready.""" + ready_cycle = self.cycle + latency + + # Get destination registers + if isinstance(inst, (SOP1, SOP2, SOPK)): + if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT: + self.sgpr_ready[inst.sdst] = ready_cycle + if inst.dst_regs() == 2: + self.sgpr_ready[inst.sdst + 1] = ready_cycle + elif isinstance(inst, SMEM): + if hasattr(inst, 'sdata'): + cnt = SMEM_LOAD.get(inst.op, 1) + for i in range(cnt): + self.sgpr_ready[inst.sdata + i] = ready_cycle + elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOPC, VOPD)): + if hasattr(inst, 'vdst'): + self.vgpr_ready[inst.vdst] = ready_cycle + if inst.dst_regs() == 2: + self.vgpr_ready[inst.vdst + 1] = ready_cycle + elif isinstance(inst, FLAT): + if 'LOAD' in inst.op.name and hasattr(inst, 'vdst'): + ndwords = _op_ndwords(inst.op.name) + for i in range(ndwords): + self.vgpr_ready[inst.vdst + i] = ready_cycle + + def emit_wavestart(self): + """Emit WAVESTART packet.""" + from extra.assembly.amd.sqtt import Packet, Op + fields = (self.wave_id & 0xF) << 8 | (self.simd & 0x3) << 12 + self.packets.append(Packet(opcode=Op.WAVESTART, delta=0, fields=fields, time=self.cycle)) + + def emit_inst(self, inst: Inst, pc: int, stall: int, dur: int): + """Emit INST packet for instruction dispatch.""" + from extra.assembly.amd.sqtt import Packet, Op + inst_type = classify_inst(inst) + # Fields: wave[3:0] at [11:8], inst_type at higher bits + fields = (self.wave_id & 0xF) << 8 | (inst_type & 0xFF) << 12 + delta = stall + dur + self.packets.append(Packet(opcode=Op.INST, delta=delta, fields=fields, time=self.cycle)) + + def emit_waveend(self): + """Emit WAVEEND packet.""" + from extra.assembly.amd.sqtt import Packet, Op + fields = (self.wave_id & 0xF) << 8 | (self.simd & 0x3) << 12 + self.packets.append(Packet(opcode=Op.WAVEEND, delta=0, fields=fields, time=self.cycle)) + + def to_sqtt_blob(self) -> bytes: + """Encode all packets to raw SQTT blob.""" + from extra.assembly.amd.sqtt import Packet, Op, encode + # Prepend LAYOUT_HEADER + header = Packet(opcode=Op.LAYOUT_HEADER, delta=0, fields=0x68900180, time=0) + return encode([header] + self.packets) + # VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup) _VOPD_TO_VOP = { VOPDOp.V_DUAL_FMAC_F32: VOP3Op.V_FMAC_F32, VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32, diff --git a/extra/assembly/amd/sqtt.py b/extra/assembly/amd/sqtt.py new file mode 100644 index 0000000000..47c28aaddd --- /dev/null +++ b/extra/assembly/amd/sqtt.py @@ -0,0 +1,394 @@ +"""SQTT (SQ Thread Trace) packet encoder and decoder for AMD GPUs. + +This module provides encoding and decoding of raw SQTT byte streams. +The format is nibble-based with variable-width packets determined by a state machine. +Uses BitField infrastructure from dsl.py, similar to GPU instruction encoding. +""" +from __future__ import annotations +from enum import IntEnum +from typing import get_type_hints +from extra.assembly.amd.dsl import BitField, bits + +# ═══════════════════════════════════════════════════════════════════════════════ +# FIELD ENUMS +# ═══════════════════════════════════════════════════════════════════════════════ + +class MemSrc(IntEnum): + LDS = 0 + LDS_ALT = 1 + VMEM = 2 + VMEM_ALT = 3 + +class AluSrc(IntEnum): + NONE = 0 + SALU = 1 + VALU = 2 + VALU_ALT = 3 + +class InstOp(IntEnum): + SALU = 0x0 + SMEM = 0x1 + JUMP = 0x3 + NEXT = 0x4 + MESSAGE = 0x9 + VALU = 0xb + VALU_64 = 0xd + VALU_MAD64 = 0xe + VMEM_LOAD = 0x21 + VMEM_LOAD_ALT = 0x22 + VMEM_STORE = 0x24 + VMEM_STORE_64 = 0x25 + VMEM_STORE_96 = 0x26 + VMEM_STORE_128 = 0x27 + VMEM_STORE_ALT = 0x28 + LDS_LOAD = 0x29 + LDS_STORE = 0x2b + LDS_STORE_128 = 0x2e + SIMD_LDS_LOAD = 0x50 + SIMD_LDS_LOAD_ALT = 0x51 + SIMD_LDS_STORE = 0x54 + SIMD_VMEM_LOAD = 0x5a + SIMD_VMEM_LOAD_ALT = 0x5b + SIMD_VMEM_STORE = 0x5c + SIMD_VMEM_STORE_ALT = 0x5d + SIMD_VMEM_STORE_96 = 0x5e + SIMD_VMEM_STORE_128 = 0x5f + SALU_OR = 0x72 + VALU_CMPX = 0x73 + +# ═══════════════════════════════════════════════════════════════════════════════ +# PACKET TYPE BASE CLASS +# ═══════════════════════════════════════════════════════════════════════════════ + +class PacketType: + """Base class for SQTT packet types.""" + _encoding: tuple[BitField, int] | None = None + _field_types: dict[str, type] = {} + _values: dict[str, int] + _raw: int + _time: int + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if 'encoding' in cls.__dict__ and isinstance(cls.__dict__['encoding'], tuple): + cls._encoding = cls.__dict__['encoding'] + # Cache field type annotations for enum conversion + try: + cls._field_types = {k: v for k, v in get_type_hints(cls).items() + if isinstance(v, type) and issubclass(v, IntEnum)} + except Exception: + cls._field_types = {} + + @classmethod + def fields(cls) -> dict[str, BitField]: + return {k: v for k, v in cls.__dict__.items() if isinstance(v, BitField) and k != 'encoding'} + + @classmethod + def size_bits(cls) -> int: + max_bit = max((f.hi for f in cls.fields().values()), default=0) + return ((max_bit + 4) // 4) * 4 + + @classmethod + def size_nibbles(cls) -> int: + return cls.size_bits() // 4 + + @classmethod + def from_raw(cls, raw: int, time: int = 0): + inst = object.__new__(cls) + inst._raw = raw + inst._time = time + inst._values = {} + for name, bf in cls.fields().items(): + val = (raw >> bf.lo) & bf.mask() + # Convert to enum if annotated + enum_type = cls._field_types.get(name) + if enum_type is not None: + try: val = enum_type(val) + except ValueError: pass + inst._values[name] = val + return inst + + def __getattr__(self, name: str): + if name.startswith('_'): raise AttributeError(name) + return self._values.get(name, 0) + + def __repr__(self) -> str: + fields_str = ", ".join(f"{k}={v}" for k, v in self._values.items() if not k.startswith('_')) + return f"{self.__class__.__name__}({fields_str})" + +# ═══════════════════════════════════════════════════════════════════════════════ +# PACKET TYPE DEFINITIONS +# ═══════════════════════════════════════════════════════════════════════════════ + +class VALUINST(PacketType): + encoding = bits[2:0] == 0b011 + delta = bits[5:3] + flag = bits[6:6] + wave = bits[11:7] + +class VMEMEXEC(PacketType): + encoding = bits[3:0] == 0b1111 + delta = bits[5:4] + src: MemSrc = bits[7:6] + +class ALUEXEC(PacketType): + encoding = bits[3:0] == 0b1110 + delta = bits[5:4] + src: AluSrc = bits[7:6] + +class IMMEDIATE(PacketType): + encoding = bits[3:0] == 0b1101 + delta = bits[6:4] + wave = bits[11:7] + +class IMMEDIATE_MASK(PacketType): + encoding = bits[4:0] == 0b00100 + delta = bits[7:5] + mask = bits[23:8] + +class WAVERDY(PacketType): + encoding = bits[4:0] == 0b10100 + delta = bits[7:5] + mask = bits[23:8] + +class TS_DELTA_S8_W3(PacketType): + encoding = bits[6:0] == 0b0100001 + delta = bits[10:8] + _padding = bits[63:11] + +class WAVEEND(PacketType): + encoding = bits[4:0] == 0b10101 + delta = bits[7:5] + flag7 = bits[8:8] + simd = bits[10:9] + cu_lo = bits[13:11] + wave = bits[19:15] + @property + def cu(self) -> int: return self.cu_lo | (self.flag7 << 3) + +class WAVESTART(PacketType): + encoding = bits[4:0] == 0b01100 + delta = bits[6:5] + flag7 = bits[7:7] + simd = bits[9:8] + cu_lo = bits[12:10] + wave = bits[17:13] + id7 = bits[31:18] + @property + def cu(self) -> int: return self.cu_lo | (self.flag7 << 3) + +class TS_DELTA_S5_W2(PacketType): + encoding = bits[4:0] == 0b11100 + delta = bits[6:5] + _padding = bits[47:7] + +class WAVEALLOC(PacketType): + encoding = bits[4:0] == 0b00101 + delta = bits[7:5] + _padding = bits[19:8] + +class TS_DELTA_S5_W3(PacketType): + encoding = bits[4:0] == 0b00110 + delta = bits[7:5] + _padding = bits[51:8] + +class PERF(PacketType): + encoding = bits[4:0] == 0b10110 + delta = bits[7:5] + arg = bits[27:8] + +class TS_DELTA_SHORT(PacketType): + encoding = bits[3:0] == 0b1000 + delta = bits[7:4] + +class NOP(PacketType): + encoding = bits[3:0] == 0b0000 + delta = None # type: ignore + _padding = bits[3:0] + +class TS_WAVE_STATE(PacketType): + encoding = bits[6:0] == 0b1010001 + delta = bits[15:7] + coarse = bits[23:16] + @property + def wave_interest(self) -> bool: return bool(self.coarse & 1) + @property + def terminate_all(self) -> bool: return bool(self.coarse & 8) + +class EVENT(PacketType): + encoding = bits[7:0] == 0b01100001 + delta = bits[10:8] + event = bits[23:11] + +class EVENT_BIG(PacketType): + encoding = bits[7:0] == 0b11100001 + delta = bits[10:8] + event = bits[31:11] + +class REG(PacketType): + encoding = bits[3:0] == 0b1001 + delta = bits[6:4] + slot = bits[9:7] + hi_byte = bits[15:8] + subop = bits[31:16] + val32 = bits[63:32] + @property + def is_config(self) -> bool: return bool(self.hi_byte & 0x80) + +class SNAPSHOT(PacketType): + encoding = bits[6:0] == 0b1110001 + delta = bits[9:7] + snap = bits[63:10] + +class TS_DELTA_OR_MARK(PacketType): + encoding = bits[6:0] == 0b0000001 + delta = bits[47:12] + bit8 = bits[8:8] + bit9 = bits[9:9] + @property + def is_marker(self) -> bool: return bool(self.bit9 and not self.bit8) + +class LAYOUT_HEADER(PacketType): + encoding = bits[6:0] == 0b0010001 + delta = None # type: ignore + layout = bits[12:7] + simd = bits[14:13] + group = bits[17:15] + sel_a = bits[31:28] + sel_b = bits[36:33] + flag4 = bits[59:59] + _padding = bits[63:60] + +class INST(PacketType): + encoding = bits[2:0] == 0b010 + delta = bits[6:4] + flag1 = bits[3:3] + flag2 = bits[7:7] + wave = bits[12:8] + op: InstOp = bits[19:13] + +class UTILCTR(PacketType): + encoding = bits[6:0] == 0b0110001 + delta = bits[8:7] + ctr = bits[47:9] + +# All packet types in encoding priority order (more specific masks first, NOP last as fallback) +PACKET_TYPES: list[type[PacketType]] = [ + EVENT, EVENT_BIG, + TS_DELTA_S8_W3, TS_WAVE_STATE, SNAPSHOT, TS_DELTA_OR_MARK, LAYOUT_HEADER, UTILCTR, + IMMEDIATE_MASK, WAVERDY, WAVEEND, WAVESTART, TS_DELTA_S5_W2, WAVEALLOC, TS_DELTA_S5_W3, PERF, + VMEMEXEC, ALUEXEC, IMMEDIATE, TS_DELTA_SHORT, REG, + VALUINST, INST, + NOP, +] + +PACKET_BY_NAME: dict[str, type[PacketType]] = {cls.__name__: cls for cls in PACKET_TYPES} + +def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]: + table = [len(PACKET_TYPES) - 1] * 256 # default to NOP + opcode_to_class: dict[int, type[PacketType]] = {i: cls for i, cls in enumerate(PACKET_TYPES)} + + for byte_val in range(256): + for opcode, pkt_cls in enumerate(PACKET_TYPES): + if pkt_cls._encoding is None: continue + mask_bf, pattern = pkt_cls._encoding + if (byte_val & mask_bf.mask()) == pattern: + table[byte_val] = opcode + break + + return bytes(table), opcode_to_class + +STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table() + +OPCODE_TO_BYTES: dict[int, list[int]] = {} +for _byte_val, _opcode in enumerate(STATE_TO_OPCODE): + if _opcode not in OPCODE_TO_BYTES: OPCODE_TO_BYTES[_opcode] = [] + OPCODE_TO_BYTES[_opcode].append(_byte_val) + +BUDGET = {opcode: pkt_cls.size_nibbles() for opcode, pkt_cls in OPCODE_TO_CLASS.items()} + +# ═══════════════════════════════════════════════════════════════════════════════ +# DECODER +# ═══════════════════════════════════════════════════════════════════════════════ + +def decode(data: bytes) -> list[PacketType]: + """Decode raw SQTT blob into list of packet instances.""" + packets: list[PacketType] = [] + n = len(data) + reg = 0 + offset = 0 + nib_count = 16 + time = 0 + + while (offset >> 3) < n: + target = offset + nib_count * 4 + while offset < target and (offset >> 3) < n: + byte = data[offset >> 3] + nib = (byte >> (offset & 4)) & 0xF + reg = ((reg >> 4) | (nib << 60)) & ((1 << 64) - 1) + offset += 4 + if offset < target: break + + opcode = STATE_TO_OPCODE[reg & 0xFF] + pkt_cls = OPCODE_TO_CLASS[opcode] + nib_count = BUDGET[opcode] + + delta_field = getattr(pkt_cls, 'delta', None) + delta = (reg >> delta_field.lo) & delta_field.mask() if delta_field is not None else 0 + + if pkt_cls is TS_DELTA_OR_MARK: + bit8 = (reg >> TS_DELTA_OR_MARK.bit8.lo) & TS_DELTA_OR_MARK.bit8.mask() + bit9 = (reg >> TS_DELTA_OR_MARK.bit9.lo) & TS_DELTA_OR_MARK.bit9.mask() + if bit9 and not bit8: delta = 0 + elif pkt_cls is TS_DELTA_SHORT: + delta = delta + 8 + + time += delta + pkt = pkt_cls.from_raw(reg, time) + packets.append(pkt) + + return packets + +# ═══════════════════════════════════════════════════════════════════════════════ +# ENCODER +# ═══════════════════════════════════════════════════════════════════════════════ + +def encode(packets: list[PacketType]) -> bytes: + """Encode a list of packet instances into raw SQTT blob.""" + if not packets: return b'' + + read_lengths = [16] + for p in packets[:-1]: + read_lengths.append(p.size_nibbles()) + + total_nibbles = sum(read_lengths) + bits_arr = [0] * (total_nibbles * 4) + + cumulative = 0 + for i, p in enumerate(packets): + cumulative += read_lengths[i] + pkt_cls = type(p) + opcode = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is pkt_cls) + + byte_vals = OPCODE_TO_BYTES.get(opcode) + if not byte_vals: raise ValueError(f"No encoding for {pkt_cls.__name__}") + opcode_byte = byte_vals[0] + + delta_field = getattr(pkt_cls, 'delta', None) + if delta_field is not None and delta_field.hi < 8: + delta = p._values.get('delta', 0) + if isinstance(delta, IntEnum): delta = delta.value + if pkt_cls is TS_DELTA_SHORT: delta = max(0, delta - 8) + delta = delta & delta_field.mask() + opcode_byte = (opcode_byte & ~(delta_field.mask() << delta_field.lo)) | (delta << delta_field.lo) + + opcode_nibble_pos = max(0, cumulative - 16) + opcode_bit_pos = opcode_nibble_pos * 4 + + for b in range(8): + if opcode_bit_pos + b < len(bits_arr): + bits_arr[opcode_bit_pos + b] = (opcode_byte >> b) & 1 + + nibbles = [sum(bits_arr[i + j] << j for j in range(4) if i + j < len(bits_arr)) for i in range(0, len(bits_arr), 4)] + while len(nibbles) % 2: nibbles.append(0) + return bytes(nibbles[i] | (nibbles[i + 1] << 4) for i in range(0, len(nibbles), 2)) diff --git a/extra/assembly/amd/test/discover_instops.py b/extra/assembly/amd/test/discover_instops.py new file mode 100644 index 0000000000..b333ed82fd --- /dev/null +++ b/extra/assembly/amd/test/discover_instops.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +"""SQTT InstOp discovery tool - finds instruction opcodes by running different instructions. + +Run with: DEBUG=1 python extra/assembly/amd/test/discover_instops.py +For full traces: DEBUG=2 python extra/assembly/amd/test/discover_instops.py +""" +import os +os.environ["SQTT"] = "1" +os.environ["PROFILE"] = "1" +os.environ["SQTT_ITRACE_SE_MASK"] = "2" # Enable instruction tracing on SE1 +os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only + +from tinygrad.helpers import DEBUG, colored +from tinygrad.runtime.ops_amd import SQTT_SIMD_SEL + +from extra.assembly.amd.autogen.rdna3.ins import ( + # VALU - basic (these are safe, just register ops) + v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32, + v_and_b32_e32, v_or_b32_e32, v_xor_b32_e32, + v_lshlrev_b32_e32, v_lshrrev_b32_e32, + # VALU - transcendental + v_exp_f32_e32, v_log_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32, + # VALU - 64-bit + v_lshlrev_b64, v_lshrrev_b64, + # VALU - compare (writes to VCC, safe) + v_cmp_eq_u32_e32, + v_cmpx_eq_u32_e32, + # SALU - basic (safe, just register ops) + s_mov_b32, s_add_u32, s_and_b32, s_or_b32, + s_lshl_b32, s_lshr_b32, + s_nop, s_endpgm, + # SALU - branch (safe if offset is 0 = next instruction) + s_branch, s_cbranch_scc0, s_cbranch_execz, + # SALU - message + s_sendmsg, +) +from extra.assembly.amd.dsl import v, s +from extra.assembly.amd.sqtt import InstOp, INST + +from extra.assembly.amd.test.test_sqtt_hw import ( + run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs +) + +# ═══════════════════════════════════════════════════════════════════════════════ +# INSTRUCTION TEST CASES - only safe instructions that don't access memory +# ═══════════════════════════════════════════════════════════════════════════════ + +INSTRUCTION_TESTS: dict[str, tuple[str, list]] = { + # SALU (0x0) - scalar ALU, just register operations + "SALU_mov": ("s_mov_b32", [s_mov_b32(s[0], 0), s_mov_b32(s[1], 1)]), + "SALU_add": ("s_add_u32", [s_mov_b32(s[0], 1), s_mov_b32(s[1], 2), s_add_u32(s[2], s[0], s[1])]), + "SALU_logic": ("s_and/or", [s_and_b32(s[2], s[0], s[1]), s_or_b32(s[3], s[0], s[1])]), + "SALU_shift": ("s_lshl/lshr", [s_lshl_b32(s[2], s[0], 1), s_lshr_b32(s[3], s[0], 1)]), + "SALU_nop": ("s_nop", [s_nop(0)]), + + # JUMP (0x3) - branch to next instruction (offset 0) + "JUMP_branch": ("s_branch", [s_branch(0)]), + + # VALU (0xb) - vector ALU, just register operations + "VALU_mov": ("v_mov_b32", [v_mov_b32_e32(v[0], 0), v_mov_b32_e32(v[1], 1.0)]), + "VALU_add": ("v_add_f32", [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0), v_add_f32_e32(v[2], v[0], v[1])]), + "VALU_mul": ("v_mul_f32", [v_mul_f32_e32(v[2], v[0], v[1])]), + "VALU_logic": ("v_and/or/xor", [v_and_b32_e32(v[2], v[0], v[1]), v_or_b32_e32(v[3], v[0], v[1]), v_xor_b32_e32(v[4], v[0], v[1])]), + "VALU_shift": ("v_lshl/lshr", [v_lshlrev_b32_e32(v[2], 1, v[0]), v_lshrrev_b32_e32(v[3], 1, v[0])]), + + # VALU transcendental - still just register ops + "VALU_exp": ("v_exp_f32", [v_mov_b32_e32(v[0], 1.0), v_exp_f32_e32(v[1], v[0])]), + "VALU_log": ("v_log_f32", [v_mov_b32_e32(v[0], 1.0), v_log_f32_e32(v[1], v[0])]), + "VALU_rcp": ("v_rcp_f32", [v_mov_b32_e32(v[0], 1.0), v_rcp_f32_e32(v[1], v[0])]), + "VALU_sqrt": ("v_sqrt_f32", [v_mov_b32_e32(v[0], 1.0), v_sqrt_f32_e32(v[1], v[0])]), + + # VALU 64-bit (0xd) + "VALU64_lshl": ("v_lshlrev_b64", [v_lshlrev_b64(v[0:1], 1, v[2:3])]), + + # VALU MAD64 (0xe) - commented out, needs proper clamp arg + # "VALU_mad64": ("v_mad_u64_u32", [v_mad_u64_u32(v[0:1], None, v[2], v[3], v[4:5])]), + + # VALU compare - writes to VCC + "VALU_cmp": ("v_cmp_eq_u32", [v_cmp_eq_u32_e32(v[0], v[1])]), + + # VALU CMPX (0x73) - modifies EXEC + "VALU_cmpx": ("v_cmpx_eq_u32", [v_cmpx_eq_u32_e32(v[0], v[1])]), +} + + +def run_with_simd_retry(instructions: list, max_retries: int = 4) -> tuple[list[bytes], list, set]: + """Run instructions and retry with different SIMD selections until we get INST packets.""" + for simd in range(max_retries): + SQTT_SIMD_SEL.value = simd + blobs = run_asm_sqtt(instructions) + packets = decode_all_blobs(blobs) + ops = get_inst_ops(packets) + if ops: + return blobs, packets, ops + # Return last attempt even if no ops found + return blobs, packets, ops + +def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]: + """Run all instruction tests and collect InstOp values.""" + discovered: dict[int, set[str]] = {} + failures: dict[str, Exception] = {} + + for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items(): + try: + blobs, packets, ops = run_with_simd_retry(instructions) + + for op in ops: + if op not in discovered: + discovered[op] = set() + discovered[op].add(f"{test_name}") + + if DEBUG >= 2: + print(f"\n{'─'*60}") + print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]} simd_sel={SQTT_SIMD_SEL.value}") + print_blobs(blobs, filter_timing=True) + if DEBUG >= 1: + status = colored("✓", "green") if ops else colored("∅", "yellow") + ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none" + print(f" {status} {test_name:25s} ops=[{ops_str}]") + + except Exception as e: + failures[test_name] = e + if DEBUG >= 1: + print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}") + + return discovered, failures + + +def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None: + """Print discovery summary.""" + known_ops = {e.value for e in InstOp} + discovered_ops = set(discovered.keys()) + + print("\n" + "=" * 60) + print("DISCOVERED INSTOP VALUES") + print("=" * 60) + + for op in sorted(discovered_ops): + try: + name = InstOp(op).name + status = colored("known", "green") + except ValueError: + name = f"UNKNOWN" + status = colored("NEW!", "yellow") + + sources = ", ".join(sorted(discovered[op])) + print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}") + + # Missing from enum + missing = known_ops - discovered_ops + if missing: + print("\n" + "=" * 60) + print("ENUM VALUES NOT DISCOVERED") + print("=" * 60) + print("(need memory ops: SMEM, VMEM, LDS)") + for op in sorted(missing): + print(f" 0x{op:02x} {InstOp(op).name}") + + # New values to add + new_ops = discovered_ops - known_ops + if new_ops: + print("\n" + "=" * 60) + print(colored("NEW INSTOP VALUES TO ADD TO ENUM", "yellow")) + print("=" * 60) + for op in sorted(new_ops): + sources = ", ".join(sorted(discovered[op])) + print(f" {op:#04x}: \"{sources}\",") + + # Stats + print("\n" + "=" * 60) + print("STATISTICS") + print("=" * 60) + print(f" Tests run: {len(INSTRUCTION_TESTS)}") + print(f" Tests passed: {len(INSTRUCTION_TESTS) - len(failures)}") + print(f" Tests failed: {len(failures)}") + print(f" Known ops: {len(known_ops)}") + print(f" Discovered: {len(discovered_ops)}") + if known_ops: + print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)") + print(f" New ops found: {len(new_ops)}") + + +if __name__ == "__main__": + print("=" * 60) + print("SQTT InstOp Discovery Tool") + print("=" * 60) + print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n") + + discovered, failures = discover_all_instops() + print_summary(discovered, failures) diff --git a/extra/assembly/amd/test/test_sqtt.py b/extra/assembly/amd/test/test_sqtt.py new file mode 100644 index 0000000000..f7c1a1b7ca --- /dev/null +++ b/extra/assembly/amd/test/test_sqtt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +"""Tests for SQTT packet codec (no hardware required).""" +import unittest +from extra.assembly.amd.sqtt import ( + LAYOUT_HEADER, WAVESTART, WAVEEND, INST, NOP, + decode, encode, PACKET_TYPES, OPCODE_TO_CLASS +) + + +class TestSQTTCodec(unittest.TestCase): + """Tests for SQTT encoder/decoder roundtrip.""" + + def test_roundtrip_simple(self): + """Test encode/decode roundtrip for simple packets.""" + test_packets = [ + LAYOUT_HEADER.from_raw(0x100), + WAVESTART.from_raw(0x0), + INST.from_raw(0x10), # delta=1 + INST.from_raw(0x10), # delta=1 + WAVEEND.from_raw(0x40), # delta=2 + ] + encoded = encode(test_packets) + decoded = decode(encoded) + + self.assertGreaterEqual(len(decoded), len(test_packets)) + for i, (orig, dec) in enumerate(zip(test_packets, decoded)): + self.assertEqual(type(orig), type(dec), f"type mismatch at {i}") + + def test_decode_empty(self): + """Test decoding empty data.""" + packets = decode(b'') + self.assertEqual(packets, []) + + def test_encode_empty(self): + """Test encoding empty list.""" + data = encode([]) + self.assertEqual(data, b'') + + def test_all_packet_types_have_encoding(self): + """All packet types should have an encoding defined.""" + for pkt_cls in PACKET_TYPES: + self.assertIsNotNone(pkt_cls._encoding, f"{pkt_cls.__name__} missing encoding") + + def test_packet_from_raw(self): + """Test creating packets from raw values.""" + # INST with wave=5, op=0x21, delta=2 + raw = (0x21 << 13) | (5 << 8) | (2 << 4) | 0b010 + pkt = INST.from_raw(raw) + self.assertEqual(pkt.wave, 5) + self.assertEqual(pkt.op, 0x21) + self.assertEqual(pkt.delta, 2) + + +class TestDecodeRealBlob(unittest.TestCase): + """Test decoding real SQTT blobs from examples.""" + + def test_decode_example_file(self): + """Test decoding a real SQTT blob from examples.""" + import pickle + from pathlib import Path + example_path = Path(__file__).parent.parent.parent.parent / "sqtt/examples/profile_plus_run_0.pkl" + if not example_path.exists(): + self.skipTest(f"Example file not found: {example_path}") + + from tinygrad.runtime.ops_amd import ProfileSQTTEvent + with open(example_path, "rb") as f: + data = pickle.load(f) + + sqtt_events = [e for e in data if isinstance(e, ProfileSQTTEvent)] + self.assertGreater(len(sqtt_events), 0, "No SQTT events in example") + + packets = decode(sqtt_events[0].blob) + self.assertGreater(len(packets), 0, "No packets decoded") + # First packet should be LAYOUT_HEADER + self.assertIsInstance(packets[0], LAYOUT_HEADER) + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/assembly/amd/test/test_sqtt_compare.py b/extra/assembly/amd/test/test_sqtt_compare.py new file mode 100644 index 0000000000..dcdf2cea3b --- /dev/null +++ b/extra/assembly/amd/test/test_sqtt_compare.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +"""Tests comparing hardware SQTT traces to emulator SQTT output. + +Run with: python -m pytest extra/assembly/amd/test/test_sqtt_compare.py -v +Requires AMD GPU with SQTT support. +""" +import os +os.environ["SQTT"] = "1" +os.environ["PROFILE"] = "1" +os.environ["AMD_LLVM"] = "0" + +import unittest, sys, contextlib +from tinygrad import Tensor +from tinygrad.device import Device +from tinygrad.uop.ops import UOp, Ops, KernelInfo +from tinygrad.runtime.ops_amd import ProfileSQTTEvent + +from extra.assembly.amd.sqtt import decode, encode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, PacketType + +# ═══════════════════════════════════════════════════════════════════════════════ +# HARDWARE SQTT CAPTURE +# ═══════════════════════════════════════════════════════════════════════════════ + +dev = Device["AMD"] + +def custom(arg: str, s: UOp | None = None) -> UOp: + return UOp(Ops.CUSTOM, src=(s,) if s is not None else (), arg=arg) + +def asm_kernel(instrs: list[str], local_size: int = 1, global_size: int = 1) -> Tensor: + """Create a kernel from inline assembly instructions.""" + name = sys._getframe(1).f_code.co_name + def fxn(_): + L = UOp.special(local_size, "lidx0") + G = UOp.special(global_size, "gidx0") + op = custom("asm volatile (") + for inst in instrs: + op = custom(f' "{inst}\\n\\t"', op) + op = custom(");", op) + return UOp.sink(op, L, G, arg=KernelInfo(name=name)) + return Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0] + +@contextlib.contextmanager +def capture_hw_sqtt(): + """Capture raw SQTT blobs from hardware execution.""" + dev.profile_events.clear() + result: dict[str, list[bytes]] = {} + yield result + for ev in dev.profile_events: + if isinstance(ev, ProfileSQTTEvent): + result.setdefault(ev.kern, []).append(ev.blob) + +# ═══════════════════════════════════════════════════════════════════════════════ +# TESTS +# ═══════════════════════════════════════════════════════════════════════════════ + +@unittest.skipIf(not hasattr(dev, 'profile_events'), "AMD device required") +class TestHardwareSQTT(unittest.TestCase): + """Tests that verify hardware SQTT capture and parsing.""" + + def test_capture_and_parse(self): + """Verify we can capture and parse SQTT for simple VALU instructions.""" + with capture_hw_sqtt() as sqtt: + asm_kernel(["v_add_f32 v10, v10, v11", "v_add_f32 v11, v11, v12", "v_add_f32 v12, v12, v13"]).realize() + self.assertGreater(len(sqtt), 0, "No SQTT data captured") + kern = list(sqtt.keys())[0] + blob = sqtt[kern][0] + packets = decode(blob) + self.assertGreater(len(packets), 0, "No packets parsed") + pkt_types = {type(p) for p in packets} + # LAYOUT_HEADER should always be present + self.assertIn(LAYOUT_HEADER, pkt_types, "Missing LAYOUT_HEADER packet") + + def test_print_raw_packets(self): + """Debug test to print raw SQTT packets.""" + with capture_hw_sqtt() as sqtt: + asm_kernel([ + "v_mov_b32 v1, 1.0", + "v_add_f32 v2, v1, v1", + "s_nop 0", + "v_mul_f32 v3, v2, v2", + ]).realize() + kern = list(sqtt.keys())[0] + blob = sqtt[kern][0] + packets = decode(blob) + print(f"\n=== Raw SQTT packets for {kern} ({len(blob)} bytes, {len(packets)} packets) ===") + for i, p in enumerate(packets): + delta = p.delta if p.delta is not None else 0 + print(f" [{i:3d}] time={p._time:6d} delta={delta:4d} {type(p).__name__:18s} raw=0x{p._raw:x}") + + def test_valu_timing(self): + """Check VALU instruction timing from raw packets.""" + with capture_hw_sqtt() as sqtt: + asm_kernel([ + "v_add_f32 v10, v10, v11", + "v_add_f32 v11, v11, v12", + "v_add_f32 v12, v12, v13", + ]).realize() + kern = list(sqtt.keys())[0] + packets = decode(sqtt[kern][0]) + inst_packets = [p for p in packets if isinstance(p, INST)] + print(f"\n=== INST packets ===") + for i, p in enumerate(inst_packets): + print(f" [{i}] time={p._time} delta={p.delta} raw=0x{p._raw:x}") + + +class TestSQTTCodec(unittest.TestCase): + """Tests for SQTT encoder/decoder roundtrip.""" + + def test_roundtrip_simple(self): + """Test encode/decode roundtrip for simple packets.""" + test_packets = [ + LAYOUT_HEADER.from_raw(0x100), + WAVESTART.from_raw(0x0), + INST.from_raw(0x10), # delta=1 + INST.from_raw(0x10), # delta=1 + WAVEEND.from_raw(0x40), # delta=2 + ] + encoded = encode(test_packets) + decoded = decode(encoded) + + self.assertGreaterEqual(len(decoded), len(test_packets)) + for i, (orig, dec) in enumerate(zip(test_packets, decoded)): + self.assertEqual(type(orig), type(dec), f"type mismatch at {i}: {orig} vs {dec}") + + def test_decode_real_blob(self): + """Test decoding a real SQTT blob from examples.""" + import pickle + from pathlib import Path + example_path = Path(__file__).parent.parent.parent.parent / "sqtt/examples/profile_plus_run_0.pkl" + if not example_path.exists(): + self.skipTest(f"Example file not found: {example_path}") + + with open(example_path, "rb") as f: + data = pickle.load(f) + + sqtt_events = [e for e in data if isinstance(e, ProfileSQTTEvent)] + self.assertGreater(len(sqtt_events), 0, "No SQTT events in example") + + packets = decode(sqtt_events[0].blob) + self.assertGreater(len(packets), 0, "No packets decoded") + # Should see common packet types + pkt_types = {type(p) for p in packets} + self.assertIn(LAYOUT_HEADER, pkt_types) + + +@unittest.skip("Emulator SQTT not yet implemented") +class TestEmulatorSQTT(unittest.TestCase): + """Tests comparing emulator SQTT to hardware SQTT.""" + + def test_simple_valu_match(self): + """Simple VALU chain should produce matching SQTT packets.""" + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/assembly/amd/test/test_sqtt_hw.py b/extra/assembly/amd/test/test_sqtt_hw.py new file mode 100644 index 0000000000..5712f3f2e6 --- /dev/null +++ b/extra/assembly/amd/test/test_sqtt_hw.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +"""Hardware tests for SQTT decoder - validates decoding of real SQTT streams. + +Run with: python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s +Requires AMD GPU with SQTT support. + +For pretty trace output: DEBUG=2 python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s +""" +import os +os.environ["SQTT"] = "1" +os.environ["PROFILE"] = "1" +os.environ["SQTT_ITRACE_SE_MASK"] = "2" # Enable instruction tracing on SE1 +os.environ["SQTT_LIMIT_SE"] = "1" # Limit execution + +import unittest +from tinygrad.helpers import DEBUG, colored +from tinygrad.device import Device +from tinygrad.runtime.ops_amd import AMDProgram, ProfileSQTTEvent +from tinygrad.runtime.support.compiler_amd import HIPCompiler + +from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32, s_mov_b32, s_add_u32, s_nop, s_endpgm +from extra.assembly.amd.dsl import v, s +from extra.assembly.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, ALUEXEC, VMEMEXEC, InstOp, AluSrc, MemSrc + +dev = Device["AMD"] + +# ═══════════════════════════════════════════════════════════════════════════════ +# PRETTY PRINTING +# ═══════════════════════════════════════════════════════════════════════════════ + +PACKET_COLORS = { + "INST": "WHITE", "VALUINST": "WHITE", + "ALUEXEC": "yellow", "VMEMEXEC": "yellow", + "WAVESTART": "blue", "WAVEEND": "blue", "WAVEALLOC": "cyan", "WAVERDY": "cyan", + "LAYOUT_HEADER": "magenta", + "NOP": "BLACK", "TS_DELTA_SHORT": "BLACK", "TS_WAVE_STATE": "BLACK", + "TS_DELTA_OR_MARK": "BLACK", "TS_DELTA_S5_W2": "BLACK", "TS_DELTA_S5_W3": "BLACK", "TS_DELTA_S8_W3": "BLACK", + "REG": "green", "EVENT": "red", "EVENT_BIG": "red", "SNAPSHOT": "green", "UTILCTR": "green", "PERF": "green", + "IMMEDIATE": "cyan", "IMMEDIATE_MASK": "cyan", +} + +def format_packet(p, last_time: int = 0) -> str: + """Format a packet for pretty printing.""" + name = type(p).__name__ + color = PACKET_COLORS.get(name, "white") + delta = p._time - last_time + + fields = [] + if isinstance(p, INST): + op = p.op + op_name = op.name if isinstance(op, InstOp) else f"0x{op:02x}" + fields = [f"wave={p.wave}", f"op={op_name}"] + if p.flag1: fields.append("flag1") + if p.flag2: fields.append("flag2") + elif isinstance(p, VALUINST): + fields = [f"wave={p.wave}"] + if p.flag: fields.append("flag") + elif isinstance(p, ALUEXEC): + src_name = p.src.name if isinstance(p.src, AluSrc) else f"{p.src}" + fields = [f"src={src_name}"] + elif isinstance(p, VMEMEXEC): + src_name = p.src.name if isinstance(p.src, MemSrc) else f"{p.src}" + fields = [f"src={src_name}"] + elif isinstance(p, WAVESTART): + fields = [f"wave={p.wave}", f"simd={p.simd}", f"cu={p.cu}"] + elif isinstance(p, WAVEEND): + fields = [f"wave={p.wave}", f"simd={p.simd}", f"cu={p.cu}"] + elif hasattr(p, '_values'): + fields = [f"{k}={v}" for k, v in p._values.items() if not k.startswith('_') and k != 'delta'] + + return f"{p._time:8d} +{delta:6d} : " + colored(f"{name:18s}", color) + f" {', '.join(fields)}" + +def print_trace(packets: list, filter_timing: bool = True) -> None: + """Print a pretty trace of a single blob's packets.""" + last_time = 0 + skip_types = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "UTILCTR"} + for p in packets: + name = type(p).__name__ + if filter_timing and name in skip_types: + last_time = p._time + continue + print(format_packet(p, last_time)) + last_time = p._time + +def print_blobs(blobs: list[bytes], filter_timing: bool = True) -> None: + """Print traces for all blobs.""" + for i, blob in enumerate(blobs): + packets = decode(blob) + print(f"\n--- Blob {i}: {len(blob)} bytes, {len(packets)} packets ---") + print_trace(packets, filter_timing) + +# ═══════════════════════════════════════════════════════════════════════════════ +# ASSEMBLY HELPERS +# ═══════════════════════════════════════════════════════════════════════════════ + +def assemble(instructions: list) -> bytes: + return b''.join(inst.to_bytes() for inst in instructions) + +def run_asm_sqtt(instructions: list, n_lanes: int = 1) -> list[bytes]: + """Run instructions on AMD hardware and return SQTT blobs.""" + compiler = HIPCompiler(dev.arch) + instructions = instructions + [s_endpgm()] + code = assemble(instructions) + + byte_str = ', '.join(f'0x{b:02x}' for b in code) + asm_src = f""".text +.globl test +.p2align 8 +.type test,@function +test: +.byte {byte_str} + +.rodata +.p2align 6 +.amdhsa_kernel test + .amdhsa_next_free_vgpr 256 + .amdhsa_next_free_sgpr 96 + .amdhsa_wavefront_size32 1 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_kernarg_size 8 + .amdhsa_group_segment_fixed_size 65536 +.end_amdhsa_kernel + +.amdgpu_metadata +--- +amdhsa.version: + - 1 + - 0 +amdhsa.kernels: + - .name: test + .symbol: test.kd + .kernarg_segment_size: 8 + .group_segment_fixed_size: 65536 + .private_segment_fixed_size: 0 + .kernarg_segment_align: 8 + .wavefront_size: 32 + .sgpr_count: 96 + .vgpr_count: 256 + .max_flat_workgroup_size: 1024 +... +.end_amdgpu_metadata +""" + + lib = compiler.compile(asm_src) + prg = AMDProgram(dev, "test", lib) + out_gpu = dev.allocator.alloc(2048) + dev.profile_events.clear() + prg(out_gpu, global_size=(1, 1, 1), local_size=(n_lanes, 1, 1), wait=True) + return [ev.blob for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent)] + +def decode_all_blobs(blobs: list[bytes]) -> list: + """Decode all blobs and combine packets.""" + all_packets = [] + for blob in blobs: + all_packets.extend(decode(blob)) + return all_packets + +def get_inst_ops(packets: list) -> set: + """Extract all InstOp values from INST packets.""" + ops = set() + for p in packets: + if isinstance(p, INST): + ops.add(p.op if isinstance(p.op, int) else p.op.value) + return ops + +# ═══════════════════════════════════════════════════════════════════════════════ +# TESTS +# ═══════════════════════════════════════════════════════════════════════════════ + +@unittest.skipIf(not hasattr(dev, 'profile_events'), "AMD device required") +class TestSQTTDecode(unittest.TestCase): + """Test SQTT decoder with real hardware traces.""" + + def test_basic_structure(self): + """Verify basic SQTT stream structure: LAYOUT_HEADER, WAVESTART, instructions, WAVEEND.""" + blobs = run_asm_sqtt([v_mov_b32_e32(v[0], 0)]) + + self.assertGreater(len(blobs), 0, "No SQTT data captured") + packets = decode_all_blobs(blobs) + + self.assertGreater(len(packets), 0, "No packets decoded") + self.assertGreater(len([p for p in packets if isinstance(p, LAYOUT_HEADER)]), 0, "No LAYOUT_HEADER packets") + self.assertGreater(len([p for p in packets if isinstance(p, WAVESTART)]), 0, "No WAVESTART packets") + self.assertGreater(len([p for p in packets if isinstance(p, WAVEEND)]), 0, "No WAVEEND packets") + + if DEBUG >= 2: + print("\n=== Basic structure trace ===") + print_trace(packets) + + def test_valu_instructions(self): + """Verify VALU instructions produce INST or VALUINST packets.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), + v_mov_b32_e32(v[1], 2.0), + v_add_f32_e32(v[2], v[0], v[1]), + v_add_f32_e32(v[3], v[2], v[1]), + v_mul_f32_e32(v[4], v[2], v[3]), + ] + blobs = run_asm_sqtt(instructions) + + self.assertGreater(len(blobs), 0, "No SQTT data captured") + packets = decode_all_blobs(blobs) + + inst_packets = [p for p in packets if isinstance(p, (INST, VALUINST))] + self.assertGreater(len(inst_packets), 0, "No INST/VALUINST packets for VALU instructions") + + if DEBUG >= 2: + print("\n=== VALU instructions trace ===") + print_trace(packets) + + def test_salu_instructions(self): + """Verify SALU instructions produce appropriate packets.""" + instructions = [ + s_mov_b32(s[0], 0), + s_mov_b32(s[1], 1), + s_add_u32(s[2], s[0], s[1]), + s_add_u32(s[3], s[2], s[1]), + s_nop(0), + ] + blobs = run_asm_sqtt(instructions) + + self.assertGreater(len(blobs), 0, "No SQTT data captured") + packets = decode_all_blobs(blobs) + + if DEBUG >= 2: + print("\n=== SALU instructions trace ===") + print_trace(packets) + + def test_timing_increases(self): + """Verify time increases monotonically through packets within each blob.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), + v_mov_b32_e32(v[1], 2.0), + v_add_f32_e32(v[2], v[0], v[1]), + v_mul_f32_e32(v[3], v[2], v[1]), + ] + blobs = run_asm_sqtt(instructions) + + self.assertGreater(len(blobs), 0, "No SQTT data captured") + for blob in blobs: + packets = decode(blob) + prev_time = 0 + for p in packets: + self.assertGreaterEqual(p._time, prev_time, f"Time decreased: {prev_time} -> {p._time}") + prev_time = p._time + + def test_wave_id_consistency(self): + """Verify wave IDs are consistent between WAVESTART/WAVEEND.""" + blobs = run_asm_sqtt([v_mov_b32_e32(v[0], 0)]) + + self.assertGreater(len(blobs), 0, "No SQTT data captured") + packets = decode_all_blobs(blobs) + + wavestarts = [p for p in packets if isinstance(p, WAVESTART)] + waveends = [p for p in packets if isinstance(p, WAVEEND)] + + if wavestarts and waveends: + start_waves = {p.wave for p in wavestarts} + end_waves = {p.wave for p in waveends} + self.assertTrue(start_waves & end_waves, "No matching wave IDs between WAVESTART and WAVEEND") + + def test_nop_sequence(self): + """Test a sequence of NOP instructions.""" + blobs = run_asm_sqtt([s_nop(0), s_nop(0), s_nop(0)]) + + self.assertGreater(len(blobs), 0, "No SQTT data captured") + packets = decode_all_blobs(blobs) + self.assertGreater(len(packets), 0, "No packets decoded") + + if DEBUG >= 2: + print("\n=== NOP sequence trace ===") + print_trace(packets, filter_timing=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/assembly/amd/test/test_sqtt_ops.py b/extra/assembly/amd/test/test_sqtt_ops.py new file mode 100644 index 0000000000..fa33cbed9c --- /dev/null +++ b/extra/assembly/amd/test/test_sqtt_ops.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +"""Tests validating SQTT packet definitions against the reference implementation. + +Verifies that: +1. Encoding patterns produce the correct STATE_TO_OPCODE table +2. Packet sizes (derived from fields) match expected budget values +3. Field extractions match attempt_sqtt_parse.py +""" +import unittest +from extra.assembly.amd.sqtt import ( + VALUINST, VMEMEXEC, ALUEXEC, IMMEDIATE, IMMEDIATE_MASK, WAVERDY, + WAVEEND, WAVESTART, PERF, TS_WAVE_STATE, EVENT, EVENT_BIG, REG, SNAPSHOT, + TS_DELTA_OR_MARK, LAYOUT_HEADER, INST, UTILCTR, TS_DELTA_SHORT, NOP, + TS_DELTA_S8_W3, TS_DELTA_S5_W2, TS_DELTA_S5_W3, WAVEALLOC, + decode, encode, OPCODE_TO_CLASS, STATE_TO_OPCODE, PACKET_TYPES, BUDGET, + AluSrc, MemSrc, InstOp +) + +# Reference table from rocprof trace decoder (attempt_sqtt_parse.py) +REFERENCE_STATE_TABLE = bytes([ + 0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, + 0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, + 0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, + 0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, + 0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, + 0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, + 0x10, 0x12, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, + 0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, + 0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, + 0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, + 0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, + 0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, + 0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, + 0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, + 0x10, 0x13, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, + 0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, +]) + +# Reference opcode -> name mapping (old opcode values from rocprof) +OLD_OPCODE_TO_NAME = { + 0x01: 'VALUINST', 0x02: 'VMEMEXEC', 0x03: 'ALUEXEC', 0x04: 'IMMEDIATE', + 0x05: 'IMMEDIATE_MASK', 0x06: 'WAVERDY', 0x07: 'TS_DELTA_S8_W3', + 0x08: 'WAVEEND', 0x09: 'WAVESTART', 0x0A: 'TS_DELTA_S5_W2', + 0x0B: 'WAVEALLOC', 0x0C: 'TS_DELTA_S5_W3', 0x0D: 'PERF', + 0x0F: 'TS_DELTA_SHORT', 0x10: 'NOP', 0x11: 'TS_WAVE_STATE', + 0x12: 'EVENT', 0x13: 'EVENT_BIG', 0x14: 'REG', 0x15: 'SNAPSHOT', + 0x16: 'TS_DELTA_OR_MARK', 0x17: 'LAYOUT_HEADER', 0x18: 'INST', + 0x19: 'UTILCTR', 0x00: 'NOP', +} + +# Reference budget values (nibbles for NEXT packet) from rocprof +REFERENCE_BUDGET_NIBBLES = { + 'VALUINST': 3, 'VMEMEXEC': 2, 'ALUEXEC': 2, 'IMMEDIATE': 3, + 'IMMEDIATE_MASK': 6, 'WAVERDY': 6, 'TS_DELTA_S8_W3': 16, + 'WAVEEND': 5, 'WAVESTART': 8, 'TS_DELTA_S5_W2': 12, + 'WAVEALLOC': 5, 'TS_DELTA_S5_W3': 13, 'PERF': 7, + 'TS_DELTA_SHORT': 2, 'NOP': 1, 'TS_WAVE_STATE': 6, + 'EVENT': 6, 'EVENT_BIG': 8, 'REG': 16, 'SNAPSHOT': 16, + 'TS_DELTA_OR_MARK': 12, 'LAYOUT_HEADER': 16, 'INST': 5, + 'UTILCTR': 12, +} + + +class TestEncodingsMatchStateTable(unittest.TestCase): + """Verify encoding patterns produce the correct state decode table.""" + + def test_all_256_bytes_decode_correctly(self): + """Each byte value should decode to the same packet type as reference.""" + mismatches = [] + for byte_val in range(256): + ref_opcode = REFERENCE_STATE_TABLE[byte_val] + ref_name = OLD_OPCODE_TO_NAME.get(ref_opcode, f"UNK_{ref_opcode:02x}") + + our_opcode = STATE_TO_OPCODE[byte_val] + our_name = OPCODE_TO_CLASS[our_opcode].__name__ + + if ref_name != our_name: + mismatches.append((byte_val, ref_name, our_name)) + + if mismatches: + msg = "\n".join(f" 0x{b:02x}: expected {r}, got {o}" for b, r, o in mismatches[:10]) + self.fail(f"State table mismatches ({len(mismatches)} total):\n{msg}") + + +class TestPacketSizesMatchBudget(unittest.TestCase): + """Verify packet sizes (from field definitions) match expected budget values.""" + + def test_all_packet_sizes(self): + """Each packet type's size should match the reference budget.""" + for pkt_cls in PACKET_TYPES: + name = pkt_cls.__name__ + expected = REFERENCE_BUDGET_NIBBLES.get(name) + if expected is None: + continue + + actual = pkt_cls.size_nibbles() + self.assertEqual(expected, actual, + f"{name}: expected {expected} nibbles, got {actual} (size_bits={pkt_cls.size_bits()})") + + +class TestFieldExtraction(unittest.TestCase): + """Test that field values are extracted correctly.""" + + def test_valuinst(self): + reg = 0b11110_1_001_011 # wave=0x1E, flag=1, delta=1 + pkt = VALUINST.from_raw(reg) + self.assertEqual(pkt.delta, 1) + self.assertEqual(pkt.flag, 1) + self.assertEqual(pkt.wave, 0x1E) + + def test_vmemexec_enum(self): + reg = 0b11_00_1111 # src=3 (VMEM_ALT), delta=0 + pkt = VMEMEXEC.from_raw(reg) + self.assertEqual(pkt.src, MemSrc.VMEM_ALT) + + def test_aluexec_enum(self): + reg = 0b10_01_1110 # src=2 (VALU), delta=1 + pkt = ALUEXEC.from_raw(reg) + self.assertEqual(pkt.src, AluSrc.VALU) + + def test_waveend(self): + reg = (0x15 << 15) | (0x7 << 11) | (0x3 << 9) | (1 << 8) | 0b10101 + pkt = WAVEEND.from_raw(reg) + self.assertEqual(pkt.flag7, 1) + self.assertEqual(pkt.simd, 3) + self.assertEqual(pkt.cu_lo, 7) + self.assertEqual(pkt.wave, 0x15) + self.assertEqual(pkt.cu, 0xF) # cu_lo | (flag7 << 3) = 7 | 8 = 15 + + def test_wavestart(self): + reg = (0x7F << 18) | (0x15 << 13) | (0x7 << 10) | (0x3 << 8) | (1 << 7) | 0b01100 + pkt = WAVESTART.from_raw(reg) + self.assertEqual(pkt.flag7, 1) + self.assertEqual(pkt.simd, 3) + self.assertEqual(pkt.cu_lo, 7) + self.assertEqual(pkt.wave, 0x15) + self.assertEqual(pkt.id7, 0x7F) + self.assertEqual(pkt.cu, 0xF) + + def test_inst_enum(self): + reg = (0x21 << 13) | (0x15 << 8) | (1 << 7) | (1 << 3) | 0b010 + pkt = INST.from_raw(reg) + self.assertEqual(pkt.flag1, 1) + self.assertEqual(pkt.flag2, 1) + self.assertEqual(pkt.wave, 0x15) + self.assertEqual(pkt.op, InstOp.VMEM_LOAD) + + def test_layout_header(self): + reg = (0b101 << 33) | (0b1010 << 28) | (0b111 << 15) | (0b11 << 13) | (0b101010 << 7) | 0b0010001 + pkt = LAYOUT_HEADER.from_raw(reg) + self.assertEqual(pkt.layout, 0b101010) + self.assertEqual(pkt.simd, 0b11) + self.assertEqual(pkt.group, 0b111) + self.assertEqual(pkt.sel_a, 0b1010) + self.assertEqual(pkt.sel_b, 0b101) + + def test_ts_delta_or_mark_modes(self): + # delta mode: bit9=0, bit8=0 + pkt_delta = TS_DELTA_OR_MARK.from_raw(0b0000001) # just the encoding pattern + self.assertFalse(pkt_delta.is_marker) + + # marker mode: bit9=1, bit8=0 + pkt_marker = TS_DELTA_OR_MARK.from_raw(0b0000001 | (1 << 9)) # bit9=1, bit8=0 + self.assertTrue(pkt_marker.is_marker) + + # other mode: bit9=1, bit8=1 (not marker) + pkt_other = TS_DELTA_OR_MARK.from_raw(0b0000001 | (1 << 8) | (1 << 9)) + self.assertFalse(pkt_other.is_marker) + + def test_reg(self): + # REG fields: slot=bits[9:7], hi_byte=bits[15:8], subop=bits[31:16], val32=bits[63:32] + # Note: slot[2:1] overlaps with hi_byte[1:0], so we need to set them consistently + # hi_byte=0x55 means bits 8-15 = 0b01010101, so slot bits 8-9 = 0b01 + # slot bit 7 = 1, so slot = 0b011 = 3 + reg = (0xDEADBEEF << 32) | (0xCAFE << 16) | (0x55 << 8) | (1 << 7) | 0b1001 + pkt = REG.from_raw(reg) + self.assertEqual(pkt.slot, 0b011) # bit7=1, bits 8-9 from hi_byte low 2 bits = 01 + self.assertEqual(pkt.hi_byte, 0x55) + self.assertEqual(pkt.subop, 0xCAFE) + self.assertEqual(pkt.val32, 0xDEADBEEF) + + +class TestRoundtrip(unittest.TestCase): + """Test encode/decode roundtrip.""" + + def test_simple_roundtrip(self): + """Test encode/decode roundtrip preserves packet types.""" + test_packets = [ + LAYOUT_HEADER.from_raw(0x100), + WAVESTART.from_raw(0x0), + INST.from_raw(0x10), + INST.from_raw(0x10), + WAVEEND.from_raw(0x40), + ] + encoded = encode(test_packets) + decoded = decode(encoded) + + self.assertGreaterEqual(len(decoded), len(test_packets)) + for i, (orig, dec) in enumerate(zip(test_packets, decoded)): + self.assertEqual(type(orig), type(dec), f"type mismatch at {i}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index b3a45dec96..f5145fce0a 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -21,7 +21,7 @@ from tinygrad.runtime.support.memory import AddrSpace if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import SQTT = ContextVar("SQTT", abs(VIZ.value)>=2) -SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0) +SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0), ContextVar("SQTT_SIMD_SEL", 0) PMC = ContextVar("PMC", abs(VIZ.value)>=2) EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h WAIT_REG_MEM_FUNCTION_EQ = 3 # == @@ -251,14 +251,15 @@ class AMDComputeQueue(HWQueue): else: self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_SIZE, base_hi=buf0_hi, size=buf0s[se].size >> 12) self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_BASE, base_lo=buf0_lo) - # NOTE: SQTT can only trace instructions on one simd per se, this selects first simd in first wgp in first sa. + # NOTE: SQTT can only trace instructions on one simd per se, this selects the simd in first wgp in first sa. # For RGP to display instruction trace it has to see it on first SE. Howerver ACE/MEC/whatever does the dispatching starting with second se, # and on amdgpu/non-AM it also does weird things with dispatch order inside se: around 7 times out of 10 it starts from the last cu, but # sometimes not, especially if the kernel has more than one wavefront which means that kernels with small global size might get unlucky and # be dispatched on something else and not be seen in instruction tracing tab. You can force the wavefronts of a kernel to be dispatched on the # CUs you want to by disabling other CUs via bits in regCOMPUTE_STATIC_THREAD_MGMT_SE and trace even kernels that only have one wavefront. + # Use SQTT_SIMD_SEL (0-3) to select which SIMD to trace within the WGP. cs_wtype = (1 << 6) if self.dev.target >= (12,0,0) else self.soc.SQ_TT_WTYPE_INCLUDE_CS_BIT - self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=0, wgp_sel=0, sa_sel=0) + self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=SQTT_SIMD_SEL.value, wgp_sel=0, sa_sel=0) reg_include = self.soc.SQ_TT_TOKEN_MASK_SQDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_SHDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_GFXUDEC_BIT | \ self.soc.SQ_TT_TOKEN_MASK_COMP_BIT | self.soc.SQ_TT_TOKEN_MASK_CONTEXT_BIT token_exclude = (1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0