mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
sqtt: new packet types, add discovery script (#14960)
This commit is contained in:
148
extra/sqtt/examples/discover_ops.py
Normal file
148
extra/sqtt/examples/discover_ops.py
Normal file
@@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env python3
|
||||
# Run all ALU and memory instructions in the ISA
|
||||
import functools, inspect
|
||||
from enum import Enum
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AddrSpace
|
||||
from tinygrad.renderer.amd.dsl import Inst, Reg, OPERANDS, SrcField, VGPRField, SGPRField, SSrcField, SBaseField, AlignedSGPRField, BitField
|
||||
from tinygrad.renderer.amd.dsl import FixedBitField, EnumBitField, s, v, NULL, VCC_LO
|
||||
from extra.gemm.amd_asm_matmul import Kernel
|
||||
|
||||
# Opcodes excluded by exact name: they mutate wave state (PC, EXEC, allocations, signals).
# This stays a mutable set because the arch-specific setup under __main__ adds entries to it.
SKIP = {
  "S_SETPC_B64", "S_SWAPPC_B64", "S_GETPC_B64", "S_RFE_B64",
  "S_BARRIER_SIGNAL_ISFIRST", "S_GET_BARRIER_STATE",
  "S_ALLOC_VGPR", "S_SLEEP_VAR",
  "S_SENDMSG_RTN_B32", "S_SENDMSG_RTN_B64",
}

# Opcodes excluded when their name contains one of these fragments:
# barriers, s_waits, wrap level atomics, and ray tracing (bvh)
SKIP_SUBSTR = [
  "SAVEEXEC", "CMPX", "WREXEC", "MOVREL", "ATOMIC", "S_BUFFER_", "S_ATC_PROBE", "BARRIER", "S_WAITCNT", "BVH",
  "DS_CMPSTORE_RTN", "DS_WRAP_RTN_B32", "DS_ORDERED_COUNT", "DS_GWS", "GS_REG", "GLOBAL_LOAD_LDS", "GLOBAL_STORE_BLOCK",
]

# Instruction class names (builder.func.__name__) that are built as ALU instructions.
ALU_FORMATS = {
  "VOP1", "VOP1_LIT", "VOP1_SDST",
  "VOP2", "VOP2_LIT",
  "VOP3", "VOP3_SDST", "VOP3SD", "VOP3P", "VOP3P_MFMA", "VOP3PX2",
  "VOPC",
  "SOP1", "SOP1_LIT", "SOP2", "SOP2_LIT", "SOPC", "SOPC_LIT", "SOPK", "SOPK_LIT",
  "VINTERP",
}

# Instruction class names that are built as memory instructions.
# intentionally not testing scratch memory ops
MEM_FORMATS = {"VGLOBAL", "GLOBAL", "SMEM", "DS"}
|
||||
|
||||
def should_skip(op:Enum) -> bool:
  """Return True when `op` is excluded: its name is in SKIP, or it contains any SKIP_SUBSTR fragment."""
  opname = op.name
  if opname in SKIP: return True
  return any(fragment in opname for fragment in SKIP_SUBSTR)
|
||||
|
||||
# ** named register assignments
|
||||
|
||||
# ALU operands: operand slot i starts at v[i*ALU_VGPR_STRIDE] and s[i*ALU_SGPR_STRIDE]
# (v[0], v[16], v[32], ... and s[0], s[4], s[8], ...)
ALU_VGPR_STRIDE, ALU_SGPR_STRIDE = 16, 4

# memory address registers, as (first, last) register index pairs
S_KERNARG_PTR, S_BUF_PTR = (0, 1), (2, 3)
V_VADDR = (0, 1)
# single VGPR used as the DS address
V_DS_ADDR = 0

# memory data registers
# vdst/vdata/vsrc use v[32], v[48], ...: a base plus a fixed spacing per data vgpr slot
MEM_VGPR_BASE, MEM_VGPR_STRIDE = 32, 16
# SMEM sdata uses s[8], s[10], ...: a base plus a fixed spacing per data sgpr slot
MEM_SGPR_BASE, MEM_SGPR_STRIDE = 8, 2
|
||||
|
||||
# ** create an ALU instruction based on the operands
|
||||
|
||||
def create_alu_inst(op:Enum, builder:functools.partial[Inst]) -> Inst:
  """Build one ALU instruction for `op`, giving each register operand its own register slot.

  Walks the `_fields` of the instruction class behind `builder` (`builder.func`):
  - FixedBitField / EnumBitField fields are not passed; the builder supplies them.
  - A field that receives a register consumes one slot. Slot i starts at
    v[i*ALU_VGPR_STRIDE] for vector registers and s[i*ALU_SGPR_STRIDE] for scalar registers.
  - Any other BitField is passed its `default`.

  The register count of an operand is OPERANDS[op][name][1] // 32 (at least 1); fields that
  are not listed in OPERANDS[op] count as one register.

  Args:
    op: opcode enum member; must be a key of OPERANDS.
    builder: partial over the instruction class with the opcode already bound.

  Returns:
    The instruction produced by calling `builder` with the generated operands.
  """
  inst_cls, operands, slot = builder.func, OPERANDS[op], 0
  kwargs:dict[str, Reg|int] = {}

  def span(bank, base:int, nregs:int):
    # multi-register operands take bank[base:base+nregs-1]; NOTE(review): this assumes the DSL slice end is inclusive
    return bank[base:base+nregs-1] if nregs > 1 else bank[base]

  for name, field in inst_cls._fields:
    if isinstance(field, (FixedBitField, EnumBitField)): continue
    nregs = max(1, operands[name][1] // 32) if name in operands else 1
    is_sreg = name in operands and "SREG" in str(operands[name][2])
    is_vgpr = isinstance(field, VGPRField)
    base_v, base_s = slot * ALU_VGPR_STRIDE, slot * ALU_SGPR_STRIDE
    # the isinstance checks below are order-sensitive: the field classes may subclass one another
    if (name == "sdst" and isinstance(field, SGPRField)) or (is_sreg and not is_vgpr): reg = VCC_LO
    elif is_vgpr: reg = span(v, base_v, nregs)
    # only nregs > 2 reaches the sgpr range, so a single-register fallback is not needed here
    elif isinstance(field, SSrcField): reg = VCC_LO if nregs <= 2 else s[base_s:base_s+nregs-1]
    elif isinstance(field, SGPRField): reg = span(s, base_s, nregs)
    elif isinstance(field, SrcField): reg = span(v, base_v, nregs)
    else: reg = None
    if reg is not None:
      kwargs[name] = reg
      slot += 1
    elif isinstance(field, BitField): kwargs[name] = field.default
  return builder(**kwargs)
|
||||
|
||||
# ** create a memory instruction with pre set address registers
|
||||
|
||||
# Per memory format (instruction class name): the address operands that always use the same registers.
# create_mem_inst copies these into the instruction instead of allocating a data slot for them.
MEM_PRESET_REGS:dict[str, dict[str, Reg]] = {
  "VGLOBAL": {
    "saddr": s[slice(*S_BUF_PTR)],
    "vaddr": v[slice(*V_VADDR)],
  },
  # addr is 32-bit offset when saddr is valid SGPR
  "GLOBAL": {
    "saddr": s[slice(*S_BUF_PTR)],
    "addr": v[V_DS_ADDR],
  },
  "DS": {
    "addr": v[V_DS_ADDR],
  },
  "SMEM": {
    "sbase": s[slice(*S_KERNARG_PTR)],
    "soffset": NULL,
  },
}
|
||||
|
||||
def create_mem_inst(op:Enum, builder:functools.partial[Inst]) -> Inst:
  """Build one memory instruction for `op`, using the preset address registers of its format.

  Fields named in MEM_PRESET_REGS for this instruction class get the preset register. Every other
  VGPR field takes the next vector data slot (starting at MEM_VGPR_BASE, MEM_VGPR_STRIDE apart) and
  every other SGPR-like field takes the next scalar data slot (starting at MEM_SGPR_BASE,
  MEM_SGPR_STRIDE apart). FixedBitField / EnumBitField fields are left to the builder, and any
  remaining BitField is passed its `default`.
  """
  inst_cls = builder.func
  widths = OPERANDS.get(op, {})
  preset = MEM_PRESET_REGS.get(inst_cls.__name__, {})
  kwargs:dict[str, Reg|int] = {}
  # running cursors: first register index of the next free vector / scalar data slot
  next_v, next_s = MEM_VGPR_BASE, MEM_SGPR_BASE
  for name, field in inst_cls._fields:
    if isinstance(field, (FixedBitField, EnumBitField)): continue
    if name in preset:
      kwargs[name] = preset[name]
      continue
    # operand width in bits -> number of 32-bit registers (at least one)
    nregs = max(1, widths[name][1] // 32) if name in widths else 1
    if isinstance(field, VGPRField):
      kwargs[name] = v[next_v] if nregs <= 1 else v[next_v:next_v+nregs-1]
      next_v += MEM_VGPR_STRIDE
    elif isinstance(field, (SGPRField, AlignedSGPRField, SBaseField)):
      kwargs[name] = s[next_s] if nregs <= 1 else s[next_s:next_s+nregs-1]
      next_s += MEM_SGPR_STRIDE
    elif isinstance(field, BitField):
      kwargs[name] = field.default
  return builder(**kwargs)
|
||||
|
||||
# ** collect all memory and ALU instructions from the ISA autogen
|
||||
|
||||
def collect_instructions() -> tuple[list[Inst], list[Inst], list[str]]:
  """Build one instruction per usable opcode found in the `all_insts` module.

  Every module member that is a functools.partial with exactly one bound positional argument is
  treated as an opcode builder, keyed by that argument. Opcodes rejected by should_skip or missing
  from OPERANDS are recorded by name in the skipped list. The rest are built as ALU or memory
  instructions according to the builder's class name; classes in neither format set are ignored.

  Returns:
    (alu_insts, mem_insts, skipped_names)
  """
  builders:dict[Enum, functools.partial[Inst]] = {
    member.args[0]: member for _, member in inspect.getmembers(all_insts)
    if isinstance(member, functools.partial) and len(member.args) == 1
  }
  alu:list[Inst] = []
  mem:list[Inst] = []
  skipped_names:list[str] = []
  for opcode, make in builders.items():
    if should_skip(opcode) or opcode not in OPERANDS:
      skipped_names.append(opcode.name)
      continue
    fmt = make.func.__name__
    if fmt in ALU_FORMATS:
      alu.append(create_alu_inst(opcode, make))
    elif fmt in MEM_FORMATS:
      mem.append(create_mem_inst(opcode, make))
  return alu, mem, skipped_names
|
||||
|
||||
def exec_insts(insts:list):
  """Assemble `insts` into a single kernel behind a global-memory prologue and run it once.

  Reads module globals that the __main__ section sets up: `arch` and the star-imported
  instruction builders (s_load_b64, v_mov_b32_e32, s_endpgm).
  """
  k = Kernel(arch)

  # ** prologue for global memory: load the buffer pointer from the kernarg pointer, zero the vaddr pair
  buf_ptr = s[S_BUF_PTR[0]:S_BUF_PTR[1]]
  kernarg_ptr = s[S_KERNARG_PTR[0]:S_KERNARG_PTR[1]]
  k.emit(s_load_b64(sdata=buf_ptr, sbase=kernarg_ptr, soffset=NULL))
  k.waitcnt(lgkm=0)
  for vaddr_idx in V_VADDR: k.emit(v_mov_b32_e32(v[vaddr_idx], 0))

  # ** emit the instructions under test, then end the program
  for inst in insts: k.emit(inst)
  k.emit(s_endpgm())

  # ** run
  n_threads, n_grids, buf_size = 32, 1, 1024*1024

  def fxn(A:UOp, B:UOp, C:UOp) -> UOp:
    # wrap the assembled instructions in a PROGRAM uop for the AMD device
    lidx = UOp.special(n_threads, "lidx0")
    gidx = UOp.special(n_grids, "gidx0")
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=buf_size, addrspace=AddrSpace.LOCAL), (), "lds")
    sink = UOp.sink(A.base, B.base, C.base, lds, lidx, gidx, arg=KernelInfo(name="discover_ops"))
    device = UOp(Ops.DEVICE, arg="AMD")
    linear = UOp(Ops.LINEAR, src=tuple(UOp(Ops.INS, arg=x) for x in k.finalize()))
    return UOp(Ops.PROGRAM, src=(sink, device, linear))

  A = Tensor.empty(buf_size, dtype=dtypes.uint8)
  B = Tensor.empty(1, dtype=dtypes.uint8)
  C = Tensor.empty(1, dtype=dtypes.uint8)
  Tensor.custom_kernel(A, B, C, fxn=fxn)[0].realize()
|
||||
|
||||
if __name__ == "__main__":
  import sys

  # `arch` is a module global: exec_insts reads it when constructing the Kernel
  arch = Device[Device.DEFAULT].renderer.arch
  if not arch.startswith(("gfx11", "gfx12")):
    print(f"{arch} not supported yet")
    sys.exit(0)

  # the star-import provides the instruction builders exec_insts calls by name;
  # `all_insts` is the module collect_instructions enumerates
  if arch.startswith("gfx12"):
    from tinygrad.runtime.autogen.amd.rdna4.ins import *
    import tinygrad.runtime.autogen.amd.rdna4.ins as all_insts
  else:
    from tinygrad.runtime.autogen.amd.rdna3.ins import *
    import tinygrad.runtime.autogen.amd.rdna3.ins as all_insts
    # these don't exist in RDNA3, only RDNA3.5 and above
    SKIP.update(("S_FMAAK_F32", "S_FMAMK_F32"))

  alu_insts, mem_insts, skipped = collect_instructions()
  print(f"collected {len(alu_insts)} ALU + {len(mem_insts)} memory instructions ({len(skipped)} skipped)")
  # memory instructions are emitted first, then the ALU instructions
  exec_insts([*mem_insts, *alu_insts])
|
||||
@@ -9,6 +9,7 @@ EXAMPLES = [
|
||||
"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
|
||||
"test/test_tiny.py TestTiny.test_plus",
|
||||
"test/test_tiny.py TestTiny.test_gemm",
|
||||
"extra/sqtt/examples/discover_ops.py"
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
BIN
extra/sqtt/examples/gfx1100/profile_py_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_py_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1100/profile_py_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_py_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_py_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_py_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_py_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_py_run_1.pkl
Normal file
Binary file not shown.
@@ -134,7 +134,10 @@ class TestSQTTMatchesBinary(unittest.TestCase):
|
||||
def _test_bit_counts(self, layout: int):
  """Compare each packet class's size in bits for `layout` (3 or 4) with rocprof's bit tables.

  Skips the test when extract_bit_tables() returns nothing.
  """
  tables = extract_bit_tables()
  if not tables: self.skipTest("rocprof-trace-decoder not installed")
  from tinygrad.renderer.amd.sqtt import PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4
  # rocprof's bit table says L4 type 7 (TS_DELTA_S8_W3) is 72 bits, but the actual decoder uses 64 bits
  known_mismatches = {(4, 7)}
  packet_types = {3: PACKET_TYPES_RDNA3, 4: PACKET_TYPES_RDNA4}[layout]
  for type_id, pkt_cls in packet_types.items():
    if (layout, type_id) in known_mismatches: continue
    with self.subTest(packet=pkt_cls.__name__):
      # packet sizes are stored in nibbles; the table for layout N is at index N-2
      self.assertEqual(pkt_cls._size_nibbles * 4, tables[layout - 2][type_id])  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import unittest, pickle
|
||||
from typing import Iterator
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.helpers import DEBUG, OSX
|
||||
from tinygrad.renderer.amd.sqtt import print_packets, map_insts
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm
|
||||
from test.amd.disasm import disasm
|
||||
@@ -10,7 +10,7 @@ from test.amd.disasm import disasm
|
||||
import tinygrad
|
||||
EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples"
|
||||
|
||||
def rocprof_inst_traces_match(sqtt, prg, target):
|
||||
def rocprof_inst_traces_match(sqtt, prg, target, pass_rocprof_err=False):
|
||||
from tinygrad.viz.serve import amd_decode
|
||||
from extra.sqtt.roc import decode as roc_decode, InstExec
|
||||
addr_table = amd_decode(prg.lib, target)
|
||||
@@ -30,7 +30,7 @@ def rocprof_inst_traces_match(sqtt, prg, target):
|
||||
rocprof_inst = next(rwaves_iter[info.wave][0])
|
||||
ref_pc = rocprof_inst.pc-prg.base
|
||||
# always check pc matches
|
||||
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
|
||||
assert ref_pc == info.pc or pass_rocprof_err, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
|
||||
# special handling for s_endpgm, it marks the wave completion.
|
||||
if info.inst == s_endpgm():
|
||||
completed_wave = list(rwaves_iter[info.wave].pop(0))
|
||||
@@ -67,7 +67,9 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
if not event.itrace: continue
|
||||
if event.kern not in kern_events: continue
|
||||
with self.subTest(example=name, kern=event.kern):
|
||||
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target)
|
||||
# rocprof OSX has a bug for sopk decoding, linux rocprof works
|
||||
pass_rocprof_err = OSX and target == "gfx1200" and name.startswith("profile_py")
|
||||
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target, pass_rocprof_err)
|
||||
if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units")
|
||||
|
||||
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
|
||||
|
||||
@@ -46,6 +46,7 @@ class InstOp(Enum):
|
||||
SMEM = 0x1
|
||||
JUMP = 0x3 # branch taken
|
||||
JUMP_NO = 0x4 # branch not taken
|
||||
CALL = 0x5 # s_call_b64
|
||||
MESSAGE = 0x9
|
||||
VALU_TRANS = 0xb # transcendental: exp, log, rcp, sqrt, sin, cos
|
||||
VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr
|
||||
@@ -72,8 +73,10 @@ class InstOp(Enum):
|
||||
|
||||
# LDS ops on traced SIMD
|
||||
LDS_LOAD = 0x29
|
||||
LDS_ATOMIC = 0x2a # ds_append, ds_consume, ds_store_addtid_b32
|
||||
LDS_STORE = 0x2b
|
||||
LDS_STORE_64 = 0x2c
|
||||
LDS_STORE_96 = 0x2d
|
||||
LDS_STORE_128 = 0x2e
|
||||
|
||||
# Memory ops on other SIMD (0x5x range)
|
||||
@@ -99,17 +102,27 @@ class InstOp(Enum):
|
||||
|
||||
class InstOpRDNA4(Enum):
|
||||
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
|
||||
# TODO: we need to do discovery of all of these from instructions
|
||||
SALU = 0x0
|
||||
JUMP = 0x1
|
||||
NEXT = 0x2
|
||||
MESSAGE = 0x4
|
||||
VALU_TRANS = 0x5
|
||||
VALU_64 = 0x6
|
||||
VALU_MAD64 = 0x7
|
||||
VINTERP = 0x9
|
||||
VALU_WMMA = 0x46
|
||||
VMEM = 0x10
|
||||
VMEM_128 = 0x11
|
||||
VMEM_STORE = 0x12
|
||||
VMEM_STORE_128 = 0x14
|
||||
VMEM_STORE_G96 = 0x13 # global_store_[b96,b128]
|
||||
LDS_LOAD = 0x14
|
||||
LDS_STORE = 0x15
|
||||
LDS_STORE_64 = 0x16
|
||||
LDS_STORE_128 = 0x17
|
||||
VALU_F64 = 0x49
|
||||
SALU_TRANS = 0x4c # transcendental with sgpr src/dst
|
||||
SALU_MUL = 0x4d # s_[mul,mulhi,mulk]
|
||||
SALU_MUL64 = 0x4e
|
||||
OTHER_VMEM = 0x5e
|
||||
OTHER_VMEM_STORE = 0x60
|
||||
|
||||
@@ -147,11 +160,6 @@ class TS_DELTA_S8_W3(PacketType):
|
||||
delta = bits[10:8]
|
||||
_padding = bits[63:11]
|
||||
|
||||
class TS_DELTA_S8_W3_RDNA4(PacketType): # Layout 4: 64->72 bits
|
||||
encoding = bits[6:0] == 0b0100001
|
||||
delta = bits[10:8]
|
||||
_padding = bits[71:11]
|
||||
|
||||
class TS_DELTA_S5_W3(PacketType):
|
||||
encoding = bits[4:0] == 0b00110
|
||||
delta = bits[7:5]
|
||||
@@ -363,7 +371,7 @@ PACKET_TYPES_RDNA3: dict[int, type[PacketType]] = {
|
||||
}
|
||||
PACKET_TYPES_RDNA4: dict[int, type[PacketType]] = {
|
||||
**PACKET_TYPES_RDNA3,
|
||||
7: TS_DELTA_S8_W3_RDNA4, 9: WAVESTART_RDNA4, 10: TS_DELTA_S5_W2_RDNA4, 11: WAVEALLOC_RDNA4,
|
||||
9: WAVESTART_RDNA4, 10: TS_DELTA_S5_W2_RDNA4, 11: WAVEALLOC_RDNA4,
|
||||
12: TS_DELTA_S5_W3_RDNA4, 13: PERF_RDNA4, 22: TS_DELTA_OR_MARK_RDNA4, 24: INST_RDNA4,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user