diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index fe8e803197..c857295543 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -15,6 +15,13 @@ class Reg: def __init__(self, offset: int = 0, sz: int = 512, *, neg: bool = False, abs_: bool = False, hi: bool = False): self.offset, self.sz = offset, sz self.neg, self.abs_, self.hi = neg, abs_, hi + + # TODO: remove these legacy aliases + @property + def count(self): return self.sz + @property + def idx(self): return self.offset + def __hash__(self): return hash((self.offset, self.sz, self.neg, self.abs_, self.hi)) def __getitem__(self, key): if isinstance(key, slice): diff --git a/extra/assembly/amd/sqtt.py b/extra/assembly/amd/sqtt.py index 032a446dff..33c5106e51 100644 --- a/extra/assembly/amd/sqtt.py +++ b/extra/assembly/amd/sqtt.py @@ -5,8 +5,9 @@ The format is nibble-based with variable-width packets determined by a state mac Uses BitField infrastructure from dsl.py, similar to GPU instruction encoding. """ from __future__ import annotations +from typing import Iterator from enum import Enum -from extra.assembly.amd.dsl import BitField, EnumBitField, FixedBitField, bits +from extra.assembly.amd.dsl import BitField, FixedBitField, bits # ═══════════════════════════════════════════════════════════════════════════════ # FIELD ENUMS @@ -302,8 +303,6 @@ STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table() # Precompute special case opcodes _TS_DELTA_OR_MARK_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_OR_MARK) _TS_DELTA_SHORT_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_SHORT) -_TS_DELTA_OR_MARK_BIT8 = (TS_DELTA_OR_MARK.bit8.lo, TS_DELTA_OR_MARK.bit8.mask()) -_TS_DELTA_OR_MARK_BIT9 = (TS_DELTA_OR_MARK.bit9.lo, TS_DELTA_OR_MARK.bit9.mask()) # Combined lookup: opcode -> (pkt_cls, nib_count, delta_lo, delta_mask, special_case) # special_case: 0=none, 1=TS_DELTA_OR_MARK, 2=TS_DELTA_SHORT @@ -319,41 +318,69 @@ for _opcode, _pkt_cls in OPCODE_TO_CLASS.items(): # DECODER # ═══════════════════════════════════════════════════════════════════════════════ -def decode(data: bytes) -> list[PacketType]: - """Decode raw SQTT blob into list of packet instances.""" - packets: list[PacketType] = [] - packets_append = packets.append - n = len(data) - reg = 0 - offset = 0 - nib_count = 16 - time = 0 - state_to_opcode = STATE_TO_OPCODE - decode_info = _DECODE_INFO - mask64 = (1 << 64) - 1 +def decode(data: bytes) -> Iterator[PacketType]: + """Decode raw SQTT blob, yielding packet instances.""" + n, reg, pos, nib_off, nib_count, time = len(data), 0, 0, 0, 16, 0 - while (offset >> 3) < n: - target = offset + nib_count * 4 - while offset < target and (offset >> 3) < n: - byte = data[offset >> 3] - nib = (byte >> (offset & 4)) & 0xF - reg = ((reg >> 4) | (nib << 60)) & mask64 - offset += 4 - if offset < target: break - - opcode = state_to_opcode[reg & 0xFF] - pkt_cls, nib_count, delta_lo, delta_mask, special = decode_info[opcode] + while pos + ((nib_count + nib_off + 1) >> 1) <= n: + need = nib_count - nib_off + # 1. if unaligned, read high nibble to align + if nib_off: reg, pos = (reg >> 4) | ((data[pos] >> 4) << 60), pos + 1 + # 2. read all full bytes at once + if (byte_count := need >> 1): + chunk = int.from_bytes(data[pos:pos + byte_count], 'little') + reg, pos = (reg >> (byte_count * 8)) | (chunk << (64 - byte_count * 8)), pos + byte_count + # 3. if odd, read low nibble + if (nib_off := need & 1): reg = (reg >> 4) | ((data[pos] & 0xF) << 60) + opcode = STATE_TO_OPCODE[reg & 0xFF] + pkt_cls, nib_count, delta_lo, delta_mask, special = _DECODE_INFO[opcode] delta = (reg >> delta_lo) & delta_mask - - if special == 1: # TS_DELTA_OR_MARK - bit8 = (reg >> _TS_DELTA_OR_MARK_BIT8[0]) & _TS_DELTA_OR_MARK_BIT8[1] - bit9 = (reg >> _TS_DELTA_OR_MARK_BIT9[0]) & _TS_DELTA_OR_MARK_BIT9[1] - if bit9 and not bit8: delta = 0 - elif special == 2: # TS_DELTA_SHORT - delta = delta + 8 - + if special == 1 and (reg >> 9) & 1 and not (reg >> 8) & 1: delta = 0 # TS_DELTA_OR_MARK marker + elif special == 2: delta += 8 # TS_DELTA_SHORT time += delta - packets_append(pkt_cls.from_raw(reg, time)) + yield pkt_cls.from_raw(reg, time) - return packets +# ═══════════════════════════════════════════════════════════════════════════════ +# PRINTER +# ═══════════════════════════════════════════════════════════════════════════════ + +PACKET_COLORS = { + "INST": "WHITE", "VALUINST": "BLACK", "VMEMEXEC": "yellow", "ALUEXEC": "yellow", + "IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW", "WAVERDY": "cyan", "WAVEALLOC": "cyan", + "WAVEEND": "blue", "WAVESTART": "blue", "PERF": "magenta", "EVENT": "red", "EVENT_BIG": "red", + "REG": "green", "LAYOUT_HEADER": "white", "SNAPSHOT": "white", "UTILCTR": "green", +} + +def format_packet(p) -> str: + from tinygrad.helpers import colored + name = type(p).__name__ + if isinstance(p, INST): + op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}" + fields = f"wave={p.wave} op={op_name}" + (" flag1" if p.flag1 else "") + (" flag2" if p.flag2 else "") + elif isinstance(p, VALUINST): fields = f"wave={p.wave}" + (" flag" if p.flag else "") + elif isinstance(p, ALUEXEC): fields = f"src={p.src.name if isinstance(p.src, AluSrc) else p.src}" + elif isinstance(p, VMEMEXEC): fields = f"src={p.src.name if isinstance(p.src, MemSrc) else p.src}" + elif isinstance(p, (WAVESTART, WAVEEND)): fields = f"wave={p.wave} simd={p.simd} cu={p.cu}" + elif hasattr(p, '_fields'): + fields = " ".join(f"{k}=0x{getattr(p, k):x}" if k in {'snap', 'val32'} else f"{k}={getattr(p, k)}" + for k in p._fields if not k.startswith('_') and k not in {'delta', 'encoding'}) + else: fields = "" + return f"{p._time:8}: {colored(f'{name:18}', PACKET_COLORS.get(name, 'white'))} {fields}" + +def print_packets(packets) -> None: + skip = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "EVENT"} + for p in packets: + if type(p).__name__ not in skip: print(format_packet(p)) + +if __name__ == "__main__": + import sys, pickle + if len(sys.argv) < 2: + print("Usage: python sqtt.py ") + sys.exit(1) + with open(sys.argv[1], "rb") as f: + data = pickle.load(f) + sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"] + for i, event in enumerate(sqtt_events): + print(f"\n=== event {i} ===") + print_packets(decode(event.blob)) diff --git a/extra/assembly/amd/test/test_sqtt_examples.py b/extra/assembly/amd/test/test_sqtt_examples.py index 7e6214fe45..afbec5a4e5 100644 --- a/extra/assembly/amd/test/test_sqtt_examples.py +++ b/extra/assembly/amd/test/test_sqtt_examples.py @@ -2,14 +2,14 @@ """Tests for SQTT packet decoding using real captured examples.""" import pickle, unittest, ctypes, threading from pathlib import Path -from tinygrad.helpers import DEBUG, colored +from tinygrad.helpers import DEBUG from tinygrad.runtime.autogen import rocprof from tinygrad.runtime.support.elf import elf_loader from extra.assembly.amd.decode import decode_inst from extra.assembly.amd.autogen.rdna3.ins import SOPP from extra.assembly.amd.autogen.rdna3.enum import SOPPOp from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK, - ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, AluSrc, MemSrc) + ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, print_packets) EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples" # INST ops for non-traced SIMDs (excluded from instruction count) @@ -19,34 +19,6 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128, InstOp.OTHER_GLOBAL_STORE_VADDR_128} -PACKET_COLORS = { - "INST": "WHITE", "VALUINST": "BLACK", "VMEMEXEC": "yellow", "ALUEXEC": "yellow", - "IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW", "WAVERDY": "cyan", "WAVEALLOC": "cyan", - "WAVEEND": "blue", "WAVESTART": "blue", "PERF": "magenta", "EVENT": "red", "EVENT_BIG": "red", - "REG": "green", "LAYOUT_HEADER": "white", "SNAPSHOT": "white", "UTILCTR": "green", -} - -def format_packet(p, time_offset: int = 0) -> str: - name, cycle = type(p).__name__, p._time - time_offset - if isinstance(p, INST): - op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}" - fields = f"wave={p.wave} op={op_name}" + (" flag1" if p.flag1 else "") + (" flag2" if p.flag2 else "") - elif isinstance(p, VALUINST): fields = f"wave={p.wave}" + (" flag" if p.flag else "") - elif isinstance(p, ALUEXEC): fields = f"src={p.src.name if isinstance(p.src, AluSrc) else p.src}" - elif isinstance(p, VMEMEXEC): fields = f"src={p.src.name if isinstance(p.src, MemSrc) else p.src}" - elif isinstance(p, (WAVESTART, WAVEEND)): fields = f"wave={p.wave} simd={p.simd} cu={p.cu}" - elif hasattr(p, '_values'): - fields = " ".join(f"{k}=0x{v:x}" if k in {'snap', 'val32'} else f"{k}={v}" - for k, v in p._values.items() if not k.startswith('_') and k != 'delta') - else: fields = "" - return f"{cycle:8}: {colored(f'{name:18}', PACKET_COLORS.get(name, 'white'))} {fields}" - -def print_packets(packets: list) -> None: - skip = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "EVENT"} - time_offset = packets[0]._time if packets else 0 - for p in packets: - if type(p).__name__ not in skip: print(format_packet(p, time_offset)) - # ═══════════════════════════════════════════════════════════════════════════════ # ROCPROF DECODER # ═══════════════════════════════════════════════════════════════════════════════ @@ -135,7 +107,7 @@ class TestSQTTExamples(unittest.TestCase): for name, (events, *_) in self.examples.items(): for i, event in enumerate(events): with self.subTest(example=name, event=i): - packets = decode(event.blob) + packets = list(decode(event.blob)) if DEBUG >= 2: print(f"\n=== {name} event {i} ==="); print_packets(packets) self.assertGreater(len(packets), 0, f"no packets decoded from {name} event {i}") self.assertIsInstance(packets[0], LAYOUT_HEADER, f"first packet should be LAYOUT_HEADER in {name}") @@ -169,6 +141,20 @@ class TestSQTTExamples(unittest.TestCase): all_packets = [p for e in events for p in decode(e.blob)] self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}") + def test_packet_counts(self): + expected = { + "profile_empty_run_0": [559, 600], + "profile_empty_run_1": [517, 570], + "profile_gemm_run_0": [1489, 604, 1789, 466, 17570, 407], + "profile_gemm_run_1": [1453, 604, 1871, 493, 17827, 460], + "profile_plus_run_0": [695, 668], + "profile_plus_run_1": [663, 593], + } + for name, (events, *_) in self.examples.items(): + with self.subTest(example=name): + counts = [len(list(decode(e.blob))) for e in events] + self.assertEqual(counts, expected[name], f"packet count mismatch in {name}") + def test_rocprof_wave_times_match(self): """Wave start/end times must match rocprof exactly.""" for name, (events, lib, base) in self.examples.items(): @@ -184,9 +170,8 @@ class TestSQTTExamples(unittest.TestCase): # extract from our decoder our_waves: list[tuple[int, int]] = [] for event in events: - packets = decode(event.blob) wave_starts: dict[tuple[int, int, int], int] = {} - for p in packets: + for p in decode(event.blob): if isinstance(p, WAVESTART): wave_starts[(p.wave, p.simd, p.cu)] = p._time elif isinstance(p, WAVEEND) and (key := (p.wave, p.simd, p.cu)) in wave_starts: our_waves.append((wave_starts[key], p._time)) diff --git a/extra/gemm/amd_asm_matmul.py b/extra/gemm/amd_asm_matmul.py index 19bf39aca1..2ef09f491e 100644 --- a/extra/gemm/amd_asm_matmul.py +++ b/extra/gemm/amd_asm_matmul.py @@ -13,7 +13,7 @@ from pathlib import Path from tinygrad import Tensor, Device, Context, GlobalCounters from tinygrad.helpers import getenv, colored from tinygrad.engine.realize import Runner, Estimates, ExecItem -from extra.assembly.amd.dsl import s, v, VCC_LO, EXEC_LO, NULL +from extra.assembly.amd.dsl import s, v, VCC_LO, NULL from extra.assembly.amd.autogen.rdna3.ins import * # ============================================================================= @@ -198,8 +198,8 @@ class Kernel: def global_load(self, vdst, addr, saddr=None): """Global load b32""" - self.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1], - saddr=s[saddr:saddr+2] if saddr else NULL)) + self.emit(global_load_b32(vdst=v[vdst], addr=v[addr] if saddr else v[addr:addr+1], + saddr=s[saddr:saddr+1] if saddr else NULL)) def waitcnt(self, lgkm=None, vm=None): """Wait for memory operations. lgkm=N waits until N lgkm ops remain, vm=N waits until N vmem ops remain.""" @@ -408,8 +408,6 @@ def build_kernel(arch='gfx1100'): for i, idx in enumerate([6,7,8,9,10,13,14,15]): offset = i * 64 k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[INIT_TILE_LOADS[idx][0]], offset0=offset & 0xFF, offset1=offset >> 8)) - k.waitcnt(lgkm=0) - k.barrier() # =========================================================================== # INIT: Compute LDS base addresses, then zero accumulators @@ -450,16 +448,21 @@ def build_kernel(arch='gfx1100'): k.emit(s_cselect_b32(s[S_PREFETCH_FLAG], -1, 0)) # s_cselect doesn't modify SCC k.emit(s_cbranch_scc0(simm16=0)); k.branch_to('SKIP_PREFETCH') # branch if loop_ctr >= loop_bound - # Advance prefetch pointers - k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR])) - k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR])) - if not NO_GLOBAL: + # Advance prefetch pointers + k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR])) + k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR])) + for vdst, saddr_lo in INIT_PREFETCH: k.global_load(vdst, V_GLOBAL_B_ADDR, saddr_lo) k.label('SKIP_PREFETCH') + # wait for local stores to finish (either initial or loop) + # then sync the warp so it's safe to load local + k.waitcnt(lgkm=0) + k.barrier() + # 8 inner loop iterations for iter in range(8): # Load A tile (4 pairs) and B tile (8 pairs) from LDS @@ -476,7 +479,7 @@ def build_kernel(arch='gfx1100'): k.waitcnt(lgkm=0) # 64 dual FMACs - k.emit(s_clause(simm16=63)) + k.emit(s_clause(simm16=len(FMAC_PATTERN)-1)) for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN): k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32, vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by])) @@ -487,10 +490,10 @@ def build_kernel(arch='gfx1100'): k.global_load(vdst1, addr, slo1) k.global_load(vdst2, addr, slo2) - k.emit(s_and_not1_b32(VCC_LO, EXEC_LO, s[S_PREFETCH_FLAG])) + # wait for all global stores to finish + # then sync the warp so it's safe to store local k.waitcnt(vm=0) k.barrier() - k.emit(s_cbranch_vccnz(simm16=0)); k.branch_to('LOOP_INC') # Store prefetched data to LDS # NOTE: Register naming reflects LDS tile organization, not source matrix: @@ -503,8 +506,6 @@ def build_kernel(arch='gfx1100'): offset = i * 64 k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8)) - k.waitcnt(lgkm=0) - k.barrier() k.emit(s_branch(simm16=0)); k.branch_to('LOOP_INC') # =========================================================================== @@ -565,7 +566,7 @@ def build_kernel(arch='gfx1100'): k.emit(global_store_b128(addr=v[0:1], data=v[tmp:tmp+3], saddr=NULL)) - k.emit(s_sendmsg(simm16=3)) + k.emit(s_sendmsg(simm16=3)) # DEALLOC_VGPRS k.emit(s_endpgm()) return k.to_asm() diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 57f1a38c3a..120ee30c3f 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -21,7 +21,8 @@ from tinygrad.runtime.support.memory import AddrSpace if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import SQTT = ContextVar("SQTT", abs(VIZ.value)>=2) -SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0) +SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL, SQTT_TOKEN_EXCLUDE = \ + ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0), ContextVar("SQTT_SIMD_SEL", 0), ContextVar("SQTT_TOKEN_EXCLUDE", 0) PMC = ContextVar("PMC", abs(VIZ.value)>=2) EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h WAIT_REG_MEM_FUNCTION_EQ = 3 # == @@ -252,17 +253,18 @@ class AMDComputeQueue(HWQueue): else: self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_SIZE, base_hi=buf0_hi, size=buf0s[se].size >> 12) self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_BASE, base_lo=buf0_lo) - # NOTE: SQTT can only trace instructions on one simd per se, this selects first simd in first wgp in first sa. + # NOTE: SQTT can only trace instructions on one simd per se, this selects the simd in first wgp in first sa. # For RGP to display instruction trace it has to see it on first SE. Howerver ACE/MEC/whatever does the dispatching starting with second se, # and on amdgpu/non-AM it also does weird things with dispatch order inside se: around 7 times out of 10 it starts from the last cu, but # sometimes not, especially if the kernel has more than one wavefront which means that kernels with small global size might get unlucky and # be dispatched on something else and not be seen in instruction tracing tab. You can force the wavefronts of a kernel to be dispatched on the # CUs you want to by disabling other CUs via bits in regCOMPUTE_STATIC_THREAD_MGMT_SE and trace even kernels that only have one wavefront. + # Use SQTT_SIMD_SEL to select which SIMD to trace (0-3). Memory ops show different InstOp values (0x2x vs 0x5x) based on SIMD. cs_wtype = (1 << 6) if self.dev.target >= (12,0,0) else self.soc.SQ_TT_WTYPE_INCLUDE_CS_BIT - self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=0, wgp_sel=0, sa_sel=0) + self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=SQTT_SIMD_SEL.value, wgp_sel=0, sa_sel=0) reg_include = self.soc.SQ_TT_TOKEN_MASK_SQDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_SHDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_GFXUDEC_BIT | \ self.soc.SQ_TT_TOKEN_MASK_COMP_BIT | self.soc.SQ_TT_TOKEN_MASK_CONTEXT_BIT - token_exclude = (1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0 + token_exclude = SQTT_TOKEN_EXCLUDE.value | ((1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0) # disable instr tracing if not (SQTT_ITRACE_SE_MASK.value >> se) & 0b1: