mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-13 08:58:05 -05:00
assembly/amd: 2% faster amd_uop_matmul + SQTT (#14122)
* assembly/amd: 2% faster amd_uop_matmul
* SQTT_TOKEN_EXCLUDE + SQTT_SIMD_SEL
* sqtt printer
* fix printer
* fast decode
* fast decoder
* test packet counts
* ugh it's not faster
* dead
This commit is contained in:
@@ -15,6 +15,13 @@ class Reg:
|
||||
def __init__(self, offset: int = 0, sz: int = 512, *, neg: bool = False, abs_: bool = False, hi: bool = False):
|
||||
self.offset, self.sz = offset, sz
|
||||
self.neg, self.abs_, self.hi = neg, abs_, hi
|
||||
|
||||
# TODO: remove these legacy aliases
|
||||
@property
|
||||
def count(self): return self.sz
|
||||
@property
|
||||
def idx(self): return self.offset
|
||||
|
||||
def __hash__(self): return hash((self.offset, self.sz, self.neg, self.abs_, self.hi))
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, slice):
|
||||
|
||||
@@ -5,8 +5,9 @@ The format is nibble-based with variable-width packets determined by a state mac
|
||||
Uses BitField infrastructure from dsl.py, similar to GPU instruction encoding.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Iterator
|
||||
from enum import Enum
|
||||
from extra.assembly.amd.dsl import BitField, EnumBitField, FixedBitField, bits
|
||||
from extra.assembly.amd.dsl import BitField, FixedBitField, bits
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# FIELD ENUMS
|
||||
@@ -302,8 +303,6 @@ STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table()
|
||||
# Precompute special case opcodes
|
||||
_TS_DELTA_OR_MARK_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_OR_MARK)
|
||||
_TS_DELTA_SHORT_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_SHORT)
|
||||
_TS_DELTA_OR_MARK_BIT8 = (TS_DELTA_OR_MARK.bit8.lo, TS_DELTA_OR_MARK.bit8.mask())
|
||||
_TS_DELTA_OR_MARK_BIT9 = (TS_DELTA_OR_MARK.bit9.lo, TS_DELTA_OR_MARK.bit9.mask())
|
||||
|
||||
# Combined lookup: opcode -> (pkt_cls, nib_count, delta_lo, delta_mask, special_case)
|
||||
# special_case: 0=none, 1=TS_DELTA_OR_MARK, 2=TS_DELTA_SHORT
|
||||
@@ -319,41 +318,69 @@ for _opcode, _pkt_cls in OPCODE_TO_CLASS.items():
|
||||
# DECODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def decode(data: bytes) -> list[PacketType]:
|
||||
"""Decode raw SQTT blob into list of packet instances."""
|
||||
packets: list[PacketType] = []
|
||||
packets_append = packets.append
|
||||
n = len(data)
|
||||
reg = 0
|
||||
offset = 0
|
||||
nib_count = 16
|
||||
time = 0
|
||||
state_to_opcode = STATE_TO_OPCODE
|
||||
decode_info = _DECODE_INFO
|
||||
mask64 = (1 << 64) - 1
|
||||
def decode(data: bytes) -> Iterator[PacketType]:
|
||||
"""Decode raw SQTT blob, yielding packet instances."""
|
||||
n, reg, pos, nib_off, nib_count, time = len(data), 0, 0, 0, 16, 0
|
||||
|
||||
while (offset >> 3) < n:
|
||||
target = offset + nib_count * 4
|
||||
while offset < target and (offset >> 3) < n:
|
||||
byte = data[offset >> 3]
|
||||
nib = (byte >> (offset & 4)) & 0xF
|
||||
reg = ((reg >> 4) | (nib << 60)) & mask64
|
||||
offset += 4
|
||||
if offset < target: break
|
||||
|
||||
opcode = state_to_opcode[reg & 0xFF]
|
||||
pkt_cls, nib_count, delta_lo, delta_mask, special = decode_info[opcode]
|
||||
while pos + ((nib_count + nib_off + 1) >> 1) <= n:
|
||||
need = nib_count - nib_off
|
||||
# 1. if unaligned, read high nibble to align
|
||||
if nib_off: reg, pos = (reg >> 4) | ((data[pos] >> 4) << 60), pos + 1
|
||||
# 2. read all full bytes at once
|
||||
if (byte_count := need >> 1):
|
||||
chunk = int.from_bytes(data[pos:pos + byte_count], 'little')
|
||||
reg, pos = (reg >> (byte_count * 8)) | (chunk << (64 - byte_count * 8)), pos + byte_count
|
||||
# 3. if odd, read low nibble
|
||||
if (nib_off := need & 1): reg = (reg >> 4) | ((data[pos] & 0xF) << 60)
|
||||
|
||||
opcode = STATE_TO_OPCODE[reg & 0xFF]
|
||||
pkt_cls, nib_count, delta_lo, delta_mask, special = _DECODE_INFO[opcode]
|
||||
delta = (reg >> delta_lo) & delta_mask
|
||||
|
||||
if special == 1: # TS_DELTA_OR_MARK
|
||||
bit8 = (reg >> _TS_DELTA_OR_MARK_BIT8[0]) & _TS_DELTA_OR_MARK_BIT8[1]
|
||||
bit9 = (reg >> _TS_DELTA_OR_MARK_BIT9[0]) & _TS_DELTA_OR_MARK_BIT9[1]
|
||||
if bit9 and not bit8: delta = 0
|
||||
elif special == 2: # TS_DELTA_SHORT
|
||||
delta = delta + 8
|
||||
|
||||
if special == 1 and (reg >> 9) & 1 and not (reg >> 8) & 1: delta = 0 # TS_DELTA_OR_MARK marker
|
||||
elif special == 2: delta += 8 # TS_DELTA_SHORT
|
||||
time += delta
|
||||
packets_append(pkt_cls.from_raw(reg, time))
|
||||
yield pkt_cls.from_raw(reg, time)
|
||||
|
||||
return packets
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PRINTER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
PACKET_COLORS = {
|
||||
"INST": "WHITE", "VALUINST": "BLACK", "VMEMEXEC": "yellow", "ALUEXEC": "yellow",
|
||||
"IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW", "WAVERDY": "cyan", "WAVEALLOC": "cyan",
|
||||
"WAVEEND": "blue", "WAVESTART": "blue", "PERF": "magenta", "EVENT": "red", "EVENT_BIG": "red",
|
||||
"REG": "green", "LAYOUT_HEADER": "white", "SNAPSHOT": "white", "UTILCTR": "green",
|
||||
}
|
||||
|
||||
def format_packet(p) -> str:
|
||||
from tinygrad.helpers import colored
|
||||
name = type(p).__name__
|
||||
if isinstance(p, INST):
|
||||
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
|
||||
fields = f"wave={p.wave} op={op_name}" + (" flag1" if p.flag1 else "") + (" flag2" if p.flag2 else "")
|
||||
elif isinstance(p, VALUINST): fields = f"wave={p.wave}" + (" flag" if p.flag else "")
|
||||
elif isinstance(p, ALUEXEC): fields = f"src={p.src.name if isinstance(p.src, AluSrc) else p.src}"
|
||||
elif isinstance(p, VMEMEXEC): fields = f"src={p.src.name if isinstance(p.src, MemSrc) else p.src}"
|
||||
elif isinstance(p, (WAVESTART, WAVEEND)): fields = f"wave={p.wave} simd={p.simd} cu={p.cu}"
|
||||
elif hasattr(p, '_fields'):
|
||||
fields = " ".join(f"{k}=0x{getattr(p, k):x}" if k in {'snap', 'val32'} else f"{k}={getattr(p, k)}"
|
||||
for k in p._fields if not k.startswith('_') and k not in {'delta', 'encoding'})
|
||||
else: fields = ""
|
||||
return f"{p._time:8}: {colored(f'{name:18}', PACKET_COLORS.get(name, 'white'))} {fields}"
|
||||
|
||||
def print_packets(packets) -> None:
|
||||
skip = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "EVENT"}
|
||||
for p in packets:
|
||||
if type(p).__name__ not in skip: print(format_packet(p))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys, pickle
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python sqtt.py <pkl_file>")
|
||||
sys.exit(1)
|
||||
with open(sys.argv[1], "rb") as f:
|
||||
data = pickle.load(f)
|
||||
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
|
||||
for i, event in enumerate(sqtt_events):
|
||||
print(f"\n=== event {i} ===")
|
||||
print_packets(decode(event.blob))
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
"""Tests for SQTT packet decoding using real captured examples."""
|
||||
import pickle, unittest, ctypes, threading
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.runtime.autogen import rocprof
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from extra.assembly.amd.decode import decode_inst
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOPP
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
|
||||
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
|
||||
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, AluSrc, MemSrc)
|
||||
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, print_packets)
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
|
||||
# INST ops for non-traced SIMDs (excluded from instruction count)
|
||||
@@ -19,34 +19,6 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
|
||||
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
|
||||
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
|
||||
|
||||
PACKET_COLORS = {
|
||||
"INST": "WHITE", "VALUINST": "BLACK", "VMEMEXEC": "yellow", "ALUEXEC": "yellow",
|
||||
"IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW", "WAVERDY": "cyan", "WAVEALLOC": "cyan",
|
||||
"WAVEEND": "blue", "WAVESTART": "blue", "PERF": "magenta", "EVENT": "red", "EVENT_BIG": "red",
|
||||
"REG": "green", "LAYOUT_HEADER": "white", "SNAPSHOT": "white", "UTILCTR": "green",
|
||||
}
|
||||
|
||||
def format_packet(p, time_offset: int = 0) -> str:
|
||||
name, cycle = type(p).__name__, p._time - time_offset
|
||||
if isinstance(p, INST):
|
||||
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
|
||||
fields = f"wave={p.wave} op={op_name}" + (" flag1" if p.flag1 else "") + (" flag2" if p.flag2 else "")
|
||||
elif isinstance(p, VALUINST): fields = f"wave={p.wave}" + (" flag" if p.flag else "")
|
||||
elif isinstance(p, ALUEXEC): fields = f"src={p.src.name if isinstance(p.src, AluSrc) else p.src}"
|
||||
elif isinstance(p, VMEMEXEC): fields = f"src={p.src.name if isinstance(p.src, MemSrc) else p.src}"
|
||||
elif isinstance(p, (WAVESTART, WAVEEND)): fields = f"wave={p.wave} simd={p.simd} cu={p.cu}"
|
||||
elif hasattr(p, '_values'):
|
||||
fields = " ".join(f"{k}=0x{v:x}" if k in {'snap', 'val32'} else f"{k}={v}"
|
||||
for k, v in p._values.items() if not k.startswith('_') and k != 'delta')
|
||||
else: fields = ""
|
||||
return f"{cycle:8}: {colored(f'{name:18}', PACKET_COLORS.get(name, 'white'))} {fields}"
|
||||
|
||||
def print_packets(packets: list) -> None:
|
||||
skip = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "EVENT"}
|
||||
time_offset = packets[0]._time if packets else 0
|
||||
for p in packets:
|
||||
if type(p).__name__ not in skip: print(format_packet(p, time_offset))
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ROCPROF DECODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -135,7 +107,7 @@ class TestSQTTExamples(unittest.TestCase):
|
||||
for name, (events, *_) in self.examples.items():
|
||||
for i, event in enumerate(events):
|
||||
with self.subTest(example=name, event=i):
|
||||
packets = decode(event.blob)
|
||||
packets = list(decode(event.blob))
|
||||
if DEBUG >= 2: print(f"\n=== {name} event {i} ==="); print_packets(packets)
|
||||
self.assertGreater(len(packets), 0, f"no packets decoded from {name} event {i}")
|
||||
self.assertIsInstance(packets[0], LAYOUT_HEADER, f"first packet should be LAYOUT_HEADER in {name}")
|
||||
@@ -169,6 +141,20 @@ class TestSQTTExamples(unittest.TestCase):
|
||||
all_packets = [p for e in events for p in decode(e.blob)]
|
||||
self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}")
|
||||
|
||||
def test_packet_counts(self):
|
||||
expected = {
|
||||
"profile_empty_run_0": [559, 600],
|
||||
"profile_empty_run_1": [517, 570],
|
||||
"profile_gemm_run_0": [1489, 604, 1789, 466, 17570, 407],
|
||||
"profile_gemm_run_1": [1453, 604, 1871, 493, 17827, 460],
|
||||
"profile_plus_run_0": [695, 668],
|
||||
"profile_plus_run_1": [663, 593],
|
||||
}
|
||||
for name, (events, *_) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
counts = [len(list(decode(e.blob))) for e in events]
|
||||
self.assertEqual(counts, expected[name], f"packet count mismatch in {name}")
|
||||
|
||||
def test_rocprof_wave_times_match(self):
|
||||
"""Wave start/end times must match rocprof exactly."""
|
||||
for name, (events, lib, base) in self.examples.items():
|
||||
@@ -184,9 +170,8 @@ class TestSQTTExamples(unittest.TestCase):
|
||||
# extract from our decoder
|
||||
our_waves: list[tuple[int, int]] = []
|
||||
for event in events:
|
||||
packets = decode(event.blob)
|
||||
wave_starts: dict[tuple[int, int, int], int] = {}
|
||||
for p in packets:
|
||||
for p in decode(event.blob):
|
||||
if isinstance(p, WAVESTART): wave_starts[(p.wave, p.simd, p.cu)] = p._time
|
||||
elif isinstance(p, WAVEEND) and (key := (p.wave, p.simd, p.cu)) in wave_starts:
|
||||
our_waves.append((wave_starts[key], p._time))
|
||||
|
||||
@@ -13,7 +13,7 @@ from pathlib import Path
|
||||
from tinygrad import Tensor, Device, Context, GlobalCounters
|
||||
from tinygrad.helpers import getenv, colored
|
||||
from tinygrad.engine.realize import Runner, Estimates, ExecItem
|
||||
from extra.assembly.amd.dsl import s, v, VCC_LO, EXEC_LO, NULL
|
||||
from extra.assembly.amd.dsl import s, v, VCC_LO, NULL
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
|
||||
# =============================================================================
|
||||
@@ -198,8 +198,8 @@ class Kernel:
|
||||
|
||||
def global_load(self, vdst, addr, saddr=None):
|
||||
"""Global load b32"""
|
||||
self.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1],
|
||||
saddr=s[saddr:saddr+2] if saddr else NULL))
|
||||
self.emit(global_load_b32(vdst=v[vdst], addr=v[addr] if saddr else v[addr:addr+1],
|
||||
saddr=s[saddr:saddr+1] if saddr else NULL))
|
||||
|
||||
def waitcnt(self, lgkm=None, vm=None):
|
||||
"""Wait for memory operations. lgkm=N waits until N lgkm ops remain, vm=N waits until N vmem ops remain."""
|
||||
@@ -408,8 +408,6 @@ def build_kernel(arch='gfx1100'):
|
||||
for i, idx in enumerate([6,7,8,9,10,13,14,15]):
|
||||
offset = i * 64
|
||||
k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[INIT_TILE_LOADS[idx][0]], offset0=offset & 0xFF, offset1=offset >> 8))
|
||||
k.waitcnt(lgkm=0)
|
||||
k.barrier()
|
||||
|
||||
# ===========================================================================
|
||||
# INIT: Compute LDS base addresses, then zero accumulators
|
||||
@@ -450,16 +448,21 @@ def build_kernel(arch='gfx1100'):
|
||||
k.emit(s_cselect_b32(s[S_PREFETCH_FLAG], -1, 0)) # s_cselect doesn't modify SCC
|
||||
k.emit(s_cbranch_scc0(simm16=0)); k.branch_to('SKIP_PREFETCH') # branch if loop_ctr >= loop_bound
|
||||
|
||||
# Advance prefetch pointers
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
|
||||
|
||||
if not NO_GLOBAL:
|
||||
# Advance prefetch pointers
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
|
||||
|
||||
for vdst, saddr_lo in INIT_PREFETCH:
|
||||
k.global_load(vdst, V_GLOBAL_B_ADDR, saddr_lo)
|
||||
|
||||
k.label('SKIP_PREFETCH')
|
||||
|
||||
# wait for local stores to finish (either initial or loop)
|
||||
# then sync the warp so it's safe to load local
|
||||
k.waitcnt(lgkm=0)
|
||||
k.barrier()
|
||||
|
||||
# 8 inner loop iterations
|
||||
for iter in range(8):
|
||||
# Load A tile (4 pairs) and B tile (8 pairs) from LDS
|
||||
@@ -476,7 +479,7 @@ def build_kernel(arch='gfx1100'):
|
||||
k.waitcnt(lgkm=0)
|
||||
|
||||
# 64 dual FMACs
|
||||
k.emit(s_clause(simm16=63))
|
||||
k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
|
||||
for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
|
||||
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
|
||||
vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
|
||||
@@ -487,10 +490,10 @@ def build_kernel(arch='gfx1100'):
|
||||
k.global_load(vdst1, addr, slo1)
|
||||
k.global_load(vdst2, addr, slo2)
|
||||
|
||||
k.emit(s_and_not1_b32(VCC_LO, EXEC_LO, s[S_PREFETCH_FLAG]))
|
||||
# wait for all global loads to finish
|
||||
# then sync the warp so it's safe to store local
|
||||
k.waitcnt(vm=0)
|
||||
k.barrier()
|
||||
k.emit(s_cbranch_vccnz(simm16=0)); k.branch_to('LOOP_INC')
|
||||
|
||||
# Store prefetched data to LDS
|
||||
# NOTE: Register naming reflects LDS tile organization, not source matrix:
|
||||
@@ -503,8 +506,6 @@ def build_kernel(arch='gfx1100'):
|
||||
offset = i * 64
|
||||
k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))
|
||||
|
||||
k.waitcnt(lgkm=0)
|
||||
k.barrier()
|
||||
k.emit(s_branch(simm16=0)); k.branch_to('LOOP_INC')
|
||||
|
||||
# ===========================================================================
|
||||
@@ -565,7 +566,7 @@ def build_kernel(arch='gfx1100'):
|
||||
|
||||
k.emit(global_store_b128(addr=v[0:1], data=v[tmp:tmp+3], saddr=NULL))
|
||||
|
||||
k.emit(s_sendmsg(simm16=3))
|
||||
k.emit(s_sendmsg(simm16=3)) # DEALLOC_VGPRS
|
||||
k.emit(s_endpgm())
|
||||
|
||||
return k.to_asm()
|
||||
|
||||
@@ -21,7 +21,8 @@ from tinygrad.runtime.support.memory import AddrSpace
|
||||
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
|
||||
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
|
||||
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL, SQTT_TOKEN_EXCLUDE = \
|
||||
ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0), ContextVar("SQTT_SIMD_SEL", 0), ContextVar("SQTT_TOKEN_EXCLUDE", 0)
|
||||
PMC = ContextVar("PMC", abs(VIZ.value)>=2)
|
||||
EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h
|
||||
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
|
||||
@@ -252,17 +253,18 @@ class AMDComputeQueue(HWQueue):
|
||||
else:
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_SIZE, base_hi=buf0_hi, size=buf0s[se].size >> 12)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_BASE, base_lo=buf0_lo)
|
||||
# NOTE: SQTT can only trace instructions on one simd per se, this selects first simd in first wgp in first sa.
|
||||
# NOTE: SQTT can only trace instructions on one simd per se, this selects the simd in first wgp in first sa.
|
||||
# For RGP to display instruction trace it has to see it on first SE. However ACE/MEC/whatever does the dispatching starting with second se,
|
||||
# and on amdgpu/non-AM it also does weird things with dispatch order inside se: around 7 times out of 10 it starts from the last cu, but
|
||||
# sometimes not, especially if the kernel has more than one wavefront which means that kernels with small global size might get unlucky and
|
||||
# be dispatched on something else and not be seen in instruction tracing tab. You can force the wavefronts of a kernel to be dispatched on the
|
||||
# CUs you want to by disabling other CUs via bits in regCOMPUTE_STATIC_THREAD_MGMT_SE<x> and trace even kernels that only have one wavefront.
|
||||
# Use SQTT_SIMD_SEL to select which SIMD to trace (0-3). Memory ops show different InstOp values (0x2x vs 0x5x) based on SIMD.
|
||||
cs_wtype = (1 << 6) if self.dev.target >= (12,0,0) else self.soc.SQ_TT_WTYPE_INCLUDE_CS_BIT
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=0, wgp_sel=0, sa_sel=0)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=SQTT_SIMD_SEL.value, wgp_sel=0, sa_sel=0)
|
||||
reg_include = self.soc.SQ_TT_TOKEN_MASK_SQDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_SHDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_GFXUDEC_BIT | \
|
||||
self.soc.SQ_TT_TOKEN_MASK_COMP_BIT | self.soc.SQ_TT_TOKEN_MASK_CONTEXT_BIT
|
||||
token_exclude = (1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0
|
||||
token_exclude = SQTT_TOKEN_EXCLUDE.value | ((1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0)
|
||||
|
||||
# disable instr tracing
|
||||
if not (SQTT_ITRACE_SE_MASK.value >> se) & 0b1:
|
||||
|
||||
Reference in New Issue
Block a user