assembly/amd: 2% faster amd_uop_matmul + SQTT (#14122)

* assembly/amd: 2% faster amd_uop_matmul

* SQTT_TOKEN_EXCLUDE + SQTT_SIMD_SEL

* sqtt printer

* fix printer

* fast decode

* fast decoder

* test packet counts

* ugh it's not faster

* dead
This commit is contained in:
George Hotz
2026-01-13 19:55:32 +09:00
committed by GitHub
parent 6cd318e377
commit a28c8105a5
5 changed files with 110 additions and 88 deletions

View File

@@ -15,6 +15,13 @@ class Reg:
def __init__(self, offset: int = 0, sz: int = 512, *, neg: bool = False, abs_: bool = False, hi: bool = False):
self.offset, self.sz = offset, sz
self.neg, self.abs_, self.hi = neg, abs_, hi
# TODO: remove these legacy aliases
@property
def count(self): return self.sz
@property
def idx(self): return self.offset
def __hash__(self): return hash((self.offset, self.sz, self.neg, self.abs_, self.hi))
def __getitem__(self, key):
if isinstance(key, slice):

View File

@@ -5,8 +5,9 @@ The format is nibble-based with variable-width packets determined by a state mac
Uses BitField infrastructure from dsl.py, similar to GPU instruction encoding.
"""
from __future__ import annotations
from typing import Iterator
from enum import Enum
from extra.assembly.amd.dsl import BitField, EnumBitField, FixedBitField, bits
from extra.assembly.amd.dsl import BitField, FixedBitField, bits
# ═══════════════════════════════════════════════════════════════════════════════
# FIELD ENUMS
@@ -302,8 +303,6 @@ STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table()
# Precompute special case opcodes
_TS_DELTA_OR_MARK_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_OR_MARK)
_TS_DELTA_SHORT_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_SHORT)
_TS_DELTA_OR_MARK_BIT8 = (TS_DELTA_OR_MARK.bit8.lo, TS_DELTA_OR_MARK.bit8.mask())
_TS_DELTA_OR_MARK_BIT9 = (TS_DELTA_OR_MARK.bit9.lo, TS_DELTA_OR_MARK.bit9.mask())
# Combined lookup: opcode -> (pkt_cls, nib_count, delta_lo, delta_mask, special_case)
# special_case: 0=none, 1=TS_DELTA_OR_MARK, 2=TS_DELTA_SHORT
@@ -319,41 +318,69 @@ for _opcode, _pkt_cls in OPCODE_TO_CLASS.items():
# DECODER
# ═══════════════════════════════════════════════════════════════════════════════
def decode(data: bytes) -> list[PacketType]:
"""Decode raw SQTT blob into list of packet instances."""
packets: list[PacketType] = []
packets_append = packets.append
n = len(data)
reg = 0
offset = 0
nib_count = 16
time = 0
state_to_opcode = STATE_TO_OPCODE
decode_info = _DECODE_INFO
mask64 = (1 << 64) - 1
def decode(data: bytes) -> Iterator[PacketType]:
"""Decode raw SQTT blob, yielding packet instances."""
n, reg, pos, nib_off, nib_count, time = len(data), 0, 0, 0, 16, 0
while (offset >> 3) < n:
target = offset + nib_count * 4
while offset < target and (offset >> 3) < n:
byte = data[offset >> 3]
nib = (byte >> (offset & 4)) & 0xF
reg = ((reg >> 4) | (nib << 60)) & mask64
offset += 4
if offset < target: break
opcode = state_to_opcode[reg & 0xFF]
pkt_cls, nib_count, delta_lo, delta_mask, special = decode_info[opcode]
while pos + ((nib_count + nib_off + 1) >> 1) <= n:
need = nib_count - nib_off
# 1. if unaligned, read high nibble to align
if nib_off: reg, pos = (reg >> 4) | ((data[pos] >> 4) << 60), pos + 1
# 2. read all full bytes at once
if (byte_count := need >> 1):
chunk = int.from_bytes(data[pos:pos + byte_count], 'little')
reg, pos = (reg >> (byte_count * 8)) | (chunk << (64 - byte_count * 8)), pos + byte_count
# 3. if odd, read low nibble
if (nib_off := need & 1): reg = (reg >> 4) | ((data[pos] & 0xF) << 60)
opcode = STATE_TO_OPCODE[reg & 0xFF]
pkt_cls, nib_count, delta_lo, delta_mask, special = _DECODE_INFO[opcode]
delta = (reg >> delta_lo) & delta_mask
if special == 1: # TS_DELTA_OR_MARK
bit8 = (reg >> _TS_DELTA_OR_MARK_BIT8[0]) & _TS_DELTA_OR_MARK_BIT8[1]
bit9 = (reg >> _TS_DELTA_OR_MARK_BIT9[0]) & _TS_DELTA_OR_MARK_BIT9[1]
if bit9 and not bit8: delta = 0
elif special == 2: # TS_DELTA_SHORT
delta = delta + 8
if special == 1 and (reg >> 9) & 1 and not (reg >> 8) & 1: delta = 0 # TS_DELTA_OR_MARK marker
elif special == 2: delta += 8 # TS_DELTA_SHORT
time += delta
packets_append(pkt_cls.from_raw(reg, time))
yield pkt_cls.from_raw(reg, time)
return packets
# ═══════════════════════════════════════════════════════════════════════════════
# PRINTER
# ═══════════════════════════════════════════════════════════════════════════════
PACKET_COLORS = {
"INST": "WHITE", "VALUINST": "BLACK", "VMEMEXEC": "yellow", "ALUEXEC": "yellow",
"IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW", "WAVERDY": "cyan", "WAVEALLOC": "cyan",
"WAVEEND": "blue", "WAVESTART": "blue", "PERF": "magenta", "EVENT": "red", "EVENT_BIG": "red",
"REG": "green", "LAYOUT_HEADER": "white", "SNAPSHOT": "white", "UTILCTR": "green",
}
def format_packet(p) -> str:
from tinygrad.helpers import colored
name = type(p).__name__
if isinstance(p, INST):
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
fields = f"wave={p.wave} op={op_name}" + (" flag1" if p.flag1 else "") + (" flag2" if p.flag2 else "")
elif isinstance(p, VALUINST): fields = f"wave={p.wave}" + (" flag" if p.flag else "")
elif isinstance(p, ALUEXEC): fields = f"src={p.src.name if isinstance(p.src, AluSrc) else p.src}"
elif isinstance(p, VMEMEXEC): fields = f"src={p.src.name if isinstance(p.src, MemSrc) else p.src}"
elif isinstance(p, (WAVESTART, WAVEEND)): fields = f"wave={p.wave} simd={p.simd} cu={p.cu}"
elif hasattr(p, '_fields'):
fields = " ".join(f"{k}=0x{getattr(p, k):x}" if k in {'snap', 'val32'} else f"{k}={getattr(p, k)}"
for k in p._fields if not k.startswith('_') and k not in {'delta', 'encoding'})
else: fields = ""
return f"{p._time:8}: {colored(f'{name:18}', PACKET_COLORS.get(name, 'white'))} {fields}"
def print_packets(packets) -> None:
skip = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "EVENT"}
for p in packets:
if type(p).__name__ not in skip: print(format_packet(p))
if __name__ == "__main__":
import sys, pickle
if len(sys.argv) < 2:
print("Usage: python sqtt.py <pkl_file>")
sys.exit(1)
with open(sys.argv[1], "rb") as f:
data = pickle.load(f)
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
for i, event in enumerate(sqtt_events):
print(f"\n=== event {i} ===")
print_packets(decode(event.blob))

View File

@@ -2,14 +2,14 @@
"""Tests for SQTT packet decoding using real captured examples."""
import pickle, unittest, ctypes, threading
from pathlib import Path
from tinygrad.helpers import DEBUG, colored
from tinygrad.helpers import DEBUG
from tinygrad.runtime.autogen import rocprof
from tinygrad.runtime.support.elf import elf_loader
from extra.assembly.amd.decode import decode_inst
from extra.assembly.amd.autogen.rdna3.ins import SOPP
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, AluSrc, MemSrc)
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, print_packets)
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
# INST ops for non-traced SIMDs (excluded from instruction count)
@@ -19,34 +19,6 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
PACKET_COLORS = {
"INST": "WHITE", "VALUINST": "BLACK", "VMEMEXEC": "yellow", "ALUEXEC": "yellow",
"IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW", "WAVERDY": "cyan", "WAVEALLOC": "cyan",
"WAVEEND": "blue", "WAVESTART": "blue", "PERF": "magenta", "EVENT": "red", "EVENT_BIG": "red",
"REG": "green", "LAYOUT_HEADER": "white", "SNAPSHOT": "white", "UTILCTR": "green",
}
def format_packet(p, time_offset: int = 0) -> str:
name, cycle = type(p).__name__, p._time - time_offset
if isinstance(p, INST):
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
fields = f"wave={p.wave} op={op_name}" + (" flag1" if p.flag1 else "") + (" flag2" if p.flag2 else "")
elif isinstance(p, VALUINST): fields = f"wave={p.wave}" + (" flag" if p.flag else "")
elif isinstance(p, ALUEXEC): fields = f"src={p.src.name if isinstance(p.src, AluSrc) else p.src}"
elif isinstance(p, VMEMEXEC): fields = f"src={p.src.name if isinstance(p.src, MemSrc) else p.src}"
elif isinstance(p, (WAVESTART, WAVEEND)): fields = f"wave={p.wave} simd={p.simd} cu={p.cu}"
elif hasattr(p, '_values'):
fields = " ".join(f"{k}=0x{v:x}" if k in {'snap', 'val32'} else f"{k}={v}"
for k, v in p._values.items() if not k.startswith('_') and k != 'delta')
else: fields = ""
return f"{cycle:8}: {colored(f'{name:18}', PACKET_COLORS.get(name, 'white'))} {fields}"
def print_packets(packets: list) -> None:
skip = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "EVENT"}
time_offset = packets[0]._time if packets else 0
for p in packets:
if type(p).__name__ not in skip: print(format_packet(p, time_offset))
# ═══════════════════════════════════════════════════════════════════════════════
# ROCPROF DECODER
# ═══════════════════════════════════════════════════════════════════════════════
@@ -135,7 +107,7 @@ class TestSQTTExamples(unittest.TestCase):
for name, (events, *_) in self.examples.items():
for i, event in enumerate(events):
with self.subTest(example=name, event=i):
packets = decode(event.blob)
packets = list(decode(event.blob))
if DEBUG >= 2: print(f"\n=== {name} event {i} ==="); print_packets(packets)
self.assertGreater(len(packets), 0, f"no packets decoded from {name} event {i}")
self.assertIsInstance(packets[0], LAYOUT_HEADER, f"first packet should be LAYOUT_HEADER in {name}")
@@ -169,6 +141,20 @@ class TestSQTTExamples(unittest.TestCase):
all_packets = [p for e in events for p in decode(e.blob)]
self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}")
def test_packet_counts(self):
expected = {
"profile_empty_run_0": [559, 600],
"profile_empty_run_1": [517, 570],
"profile_gemm_run_0": [1489, 604, 1789, 466, 17570, 407],
"profile_gemm_run_1": [1453, 604, 1871, 493, 17827, 460],
"profile_plus_run_0": [695, 668],
"profile_plus_run_1": [663, 593],
}
for name, (events, *_) in self.examples.items():
with self.subTest(example=name):
counts = [len(list(decode(e.blob))) for e in events]
self.assertEqual(counts, expected[name], f"packet count mismatch in {name}")
def test_rocprof_wave_times_match(self):
"""Wave start/end times must match rocprof exactly."""
for name, (events, lib, base) in self.examples.items():
@@ -184,9 +170,8 @@ class TestSQTTExamples(unittest.TestCase):
# extract from our decoder
our_waves: list[tuple[int, int]] = []
for event in events:
packets = decode(event.blob)
wave_starts: dict[tuple[int, int, int], int] = {}
for p in packets:
for p in decode(event.blob):
if isinstance(p, WAVESTART): wave_starts[(p.wave, p.simd, p.cu)] = p._time
elif isinstance(p, WAVEEND) and (key := (p.wave, p.simd, p.cu)) in wave_starts:
our_waves.append((wave_starts[key], p._time))

View File

@@ -13,7 +13,7 @@ from pathlib import Path
from tinygrad import Tensor, Device, Context, GlobalCounters
from tinygrad.helpers import getenv, colored
from tinygrad.engine.realize import Runner, Estimates, ExecItem
from extra.assembly.amd.dsl import s, v, VCC_LO, EXEC_LO, NULL
from extra.assembly.amd.dsl import s, v, VCC_LO, NULL
from extra.assembly.amd.autogen.rdna3.ins import *
# =============================================================================
@@ -198,8 +198,8 @@ class Kernel:
def global_load(self, vdst, addr, saddr=None):
"""Global load b32"""
self.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1],
saddr=s[saddr:saddr+2] if saddr else NULL))
self.emit(global_load_b32(vdst=v[vdst], addr=v[addr] if saddr else v[addr:addr+1],
saddr=s[saddr:saddr+1] if saddr else NULL))
def waitcnt(self, lgkm=None, vm=None):
"""Wait for memory operations. lgkm=N waits until N lgkm ops remain, vm=N waits until N vmem ops remain."""
@@ -408,8 +408,6 @@ def build_kernel(arch='gfx1100'):
for i, idx in enumerate([6,7,8,9,10,13,14,15]):
offset = i * 64
k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[INIT_TILE_LOADS[idx][0]], offset0=offset & 0xFF, offset1=offset >> 8))
k.waitcnt(lgkm=0)
k.barrier()
# ===========================================================================
# INIT: Compute LDS base addresses, then zero accumulators
@@ -450,16 +448,21 @@ def build_kernel(arch='gfx1100'):
k.emit(s_cselect_b32(s[S_PREFETCH_FLAG], -1, 0)) # s_cselect doesn't modify SCC
k.emit(s_cbranch_scc0(simm16=0)); k.branch_to('SKIP_PREFETCH') # branch if loop_ctr >= loop_bound
# Advance prefetch pointers
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
if not NO_GLOBAL:
# Advance prefetch pointers
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
for vdst, saddr_lo in INIT_PREFETCH:
k.global_load(vdst, V_GLOBAL_B_ADDR, saddr_lo)
k.label('SKIP_PREFETCH')
# wait for local stores to finish (either initial or loop)
# then sync the warp so it's safe to load local
k.waitcnt(lgkm=0)
k.barrier()
# 8 inner loop iterations
for iter in range(8):
# Load A tile (4 pairs) and B tile (8 pairs) from LDS
@@ -476,7 +479,7 @@ def build_kernel(arch='gfx1100'):
k.waitcnt(lgkm=0)
# 64 dual FMACs
k.emit(s_clause(simm16=63))
k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
@@ -487,10 +490,10 @@ def build_kernel(arch='gfx1100'):
k.global_load(vdst1, addr, slo1)
k.global_load(vdst2, addr, slo2)
k.emit(s_and_not1_b32(VCC_LO, EXEC_LO, s[S_PREFETCH_FLAG]))
# wait for all global stores to finish
# then sync the warp so it's safe to store local
k.waitcnt(vm=0)
k.barrier()
k.emit(s_cbranch_vccnz(simm16=0)); k.branch_to('LOOP_INC')
# Store prefetched data to LDS
# NOTE: Register naming reflects LDS tile organization, not source matrix:
@@ -503,8 +506,6 @@ def build_kernel(arch='gfx1100'):
offset = i * 64
k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))
k.waitcnt(lgkm=0)
k.barrier()
k.emit(s_branch(simm16=0)); k.branch_to('LOOP_INC')
# ===========================================================================
@@ -565,7 +566,7 @@ def build_kernel(arch='gfx1100'):
k.emit(global_store_b128(addr=v[0:1], data=v[tmp:tmp+3], saddr=NULL))
k.emit(s_sendmsg(simm16=3))
k.emit(s_sendmsg(simm16=3)) # DEALLOC_VGPRS
k.emit(s_endpgm())
return k.to_asm()

View File

@@ -21,7 +21,8 @@ from tinygrad.runtime.support.memory import AddrSpace
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL, SQTT_TOKEN_EXCLUDE = \
ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0), ContextVar("SQTT_SIMD_SEL", 0), ContextVar("SQTT_TOKEN_EXCLUDE", 0)
PMC = ContextVar("PMC", abs(VIZ.value)>=2)
EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
@@ -252,17 +253,18 @@ class AMDComputeQueue(HWQueue):
else:
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_SIZE, base_hi=buf0_hi, size=buf0s[se].size >> 12)
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_BASE, base_lo=buf0_lo)
# NOTE: SQTT can only trace instructions on one simd per se, this selects first simd in first wgp in first sa.
# NOTE: SQTT can only trace instructions on one simd per se, this selects the simd in first wgp in first sa.
# For RGP to display instruction trace it has to see it on first SE. Howerver ACE/MEC/whatever does the dispatching starting with second se,
# and on amdgpu/non-AM it also does weird things with dispatch order inside se: around 7 times out of 10 it starts from the last cu, but
# sometimes not, especially if the kernel has more than one wavefront which means that kernels with small global size might get unlucky and
# be dispatched on something else and not be seen in instruction tracing tab. You can force the wavefronts of a kernel to be dispatched on the
# CUs you want to by disabling other CUs via bits in regCOMPUTE_STATIC_THREAD_MGMT_SE<x> and trace even kernels that only have one wavefront.
# Use SQTT_SIMD_SEL to select which SIMD to trace (0-3). Memory ops show different InstOp values (0x2x vs 0x5x) based on SIMD.
cs_wtype = (1 << 6) if self.dev.target >= (12,0,0) else self.soc.SQ_TT_WTYPE_INCLUDE_CS_BIT
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=0, wgp_sel=0, sa_sel=0)
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=SQTT_SIMD_SEL.value, wgp_sel=0, sa_sel=0)
reg_include = self.soc.SQ_TT_TOKEN_MASK_SQDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_SHDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_GFXUDEC_BIT | \
self.soc.SQ_TT_TOKEN_MASK_COMP_BIT | self.soc.SQ_TT_TOKEN_MASK_CONTEXT_BIT
token_exclude = (1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0
token_exclude = SQTT_TOKEN_EXCLUDE.value | ((1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0)
# disable instr tracing
if not (SQTT_ITRACE_SE_MASK.value >> se) & 0b1: