diff --git a/extra/sqtt/examples/discover_ops.py b/extra/sqtt/examples/discover_ops.py new file mode 100644 index 0000000000..76f37e3f0f --- /dev/null +++ b/extra/sqtt/examples/discover_ops.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Run all ALU and memory instructions in the ISA +import functools, inspect +from enum import Enum +from tinygrad import Tensor, Device, dtypes +from tinygrad.uop.ops import UOp, Ops, KernelInfo, AddrSpace +from tinygrad.renderer.amd.dsl import Inst, Reg, OPERANDS, SrcField, VGPRField, SGPRField, SSrcField, SBaseField, AlignedSGPRField, BitField +from tinygrad.renderer.amd.dsl import FixedBitField, EnumBitField, s, v, NULL, VCC_LO +from extra.gemm.amd_asm_matmul import Kernel + +# skip instructions that mutate wave state (PC, EXEC, allocations, signals) +SKIP = {"S_SETPC_B64", "S_SWAPPC_B64", "S_RFE_B64", "S_BARRIER_SIGNAL_ISFIRST", "S_GET_BARRIER_STATE", "S_ALLOC_VGPR", "S_SLEEP_VAR", "S_GETPC_B64", + "S_SENDMSG_RTN_B32", "S_SENDMSG_RTN_B64"} +# skip barriers, s_waits, wrap level atomics, and ray tracing (bvh) +SKIP_SUBSTR = ["SAVEEXEC", "CMPX", "WREXEC", "MOVREL", "ATOMIC", "S_BUFFER_", "S_ATC_PROBE", "BARRIER", "S_WAITCNT", "BVH", + "DS_CMPSTORE_RTN", "DS_WRAP_RTN_B32", "DS_ORDERED_COUNT", "DS_GWS", "GS_REG", "GLOBAL_LOAD_LDS", "GLOBAL_STORE_BLOCK"] + +ALU_FORMATS = {"VOP1", "VOP1_LIT", "VOP1_SDST", "VOP2", "VOP2_LIT", "VOP3", "VOP3_SDST", "VOP3SD", "VOP3P", "VOP3P_MFMA", "VOP3PX2", + "VOPC", "SOP1", "SOP1_LIT", "SOP2", "SOP2_LIT", "SOPC", "SOPC_LIT", "SOPK", "SOPK_LIT", "VINTERP"} +# intentionally not testing scratch memory ops +MEM_FORMATS = {"VGLOBAL", "GLOBAL", "SMEM", "DS"} + +def should_skip(op:Enum) -> bool: return (name:=op.name) in SKIP or any(sub in name for sub in SKIP_SUBSTR) + +# ** named register assignments + +# ALU operands +ALU_VGPR_STRIDE = 16 # v[0], v[16], v[32], ... per ALU operand slot +ALU_SGPR_STRIDE = 4 # s[0], s[4], s[8], ... per ALU operand slot + +# memory address registers +S_KERNARG_PTR = (0, 1) +S_BUF_PTR = (2, 3) +V_VADDR = (0, 1) +V_DS_ADDR = 0 + +# memory data registers +MEM_VGPR_BASE = 32 # v[32], v[48], ... for vdst/vdata/vsrc +MEM_VGPR_STRIDE = 16 # spacing between memory data vgpr slots +MEM_SGPR_BASE = 8 # s[8], s[10], ... for SMEM sdata +MEM_SGPR_STRIDE = 2 # spacing between memory data sgpr slots + +# ** create an ALU instruction based on the operands + +def create_alu_inst(op:Enum, builder:functools.partial[Inst]) -> Inst: + inst_cls, operands, slot = builder.func, OPERANDS[op], 0 + kwargs:dict[str, Reg|int] = {} + for name, field in inst_cls._fields: + if isinstance(field, (FixedBitField, EnumBitField)): continue + nregs = max(1, operands[name][1] // 32) if name in operands else 1 + is_sreg = name in operands and "SREG" in str(operands[name][2]) + base_v, base_s = slot * ALU_VGPR_STRIDE, slot * ALU_SGPR_STRIDE + if name == "sdst" and isinstance(field, SGPRField): reg = VCC_LO + elif is_sreg and not isinstance(field, VGPRField): reg = VCC_LO + elif isinstance(field, VGPRField): reg = v[base_v:base_v+nregs-1] if nregs > 1 else v[base_v] + elif isinstance(field, SSrcField): reg = VCC_LO if nregs <= 2 else s[base_s:base_s+nregs-1] if nregs > 1 else s[base_s] + elif isinstance(field, SGPRField): reg = s[base_s:base_s+nregs-1] if nregs > 1 else s[base_s] + elif isinstance(field, SrcField): reg = v[base_v:base_v+nregs-1] if nregs > 1 else v[base_v] + else: reg = None + if reg is not None: kwargs[name] = reg; slot += 1 + elif isinstance(field, BitField): kwargs[name] = field.default + return builder(**kwargs) + +# ** create a memory instruction with pre set address registers + +MEM_PRESET_REGS:dict[str, dict[str, Reg]] = { + "VGLOBAL":{"saddr":s[S_BUF_PTR[0]:S_BUF_PTR[1]], "vaddr":v[V_VADDR[0]:V_VADDR[1]]}, + "GLOBAL":{"saddr":s[S_BUF_PTR[0]:S_BUF_PTR[1]], "addr":v[V_DS_ADDR]}, # addr is 32-bit offset when saddr is valid SGPR + "DS":{"addr":v[V_DS_ADDR]}, + "SMEM":{"sbase":s[S_KERNARG_PTR[0]:S_KERNARG_PTR[1]], "soffset":NULL}, +} + +def create_mem_inst(op:Enum, builder:functools.partial[Inst]) -> Inst: + inst_cls, operands, field_map = builder.func, OPERANDS.get(op, {}), MEM_PRESET_REGS.get(builder.func.__name__, {}) + kwargs:dict[str, Reg|int] = {} + vslot, sslot = 0, 0 + for name, field in inst_cls._fields: + if isinstance(field, (FixedBitField, EnumBitField)): continue + if name in field_map: + kwargs[name] = field_map[name] + continue + nregs = max(1, operands[name][1] // 32) if name in operands else 1 + if isinstance(field, VGPRField): + vi = MEM_VGPR_BASE + vslot * MEM_VGPR_STRIDE + kwargs[name] = v[vi:vi+nregs-1] if nregs > 1 else v[vi] + vslot += 1 + elif isinstance(field, (SGPRField, AlignedSGPRField, SBaseField)): + si = MEM_SGPR_BASE + sslot * MEM_SGPR_STRIDE + kwargs[name] = s[si:si+nregs-1] if nregs > 1 else s[si] + sslot += 1 + elif isinstance(field, BitField): kwargs[name] = field.default + return builder(**kwargs) + +# ** collect all memory and ALU instructions from the ISA autogen + +def collect_instructions() -> tuple[list[Inst], list[Inst], list[str]]: + op_map:dict[Enum, functools.partial[Inst]] = {} + for name, obj in inspect.getmembers(all_insts): + if isinstance(obj, functools.partial) and len(obj.args) == 1: op_map[obj.args[0]] = obj + alu_insts:list[Inst] = [] + mem_insts:list[Inst] = [] + skipped:list[str] = [] + for op_enum, builder in op_map.items(): + if should_skip(op_enum) or op_enum not in OPERANDS: skipped.append(op_enum.name); continue + fmt = builder.func.__name__ + if fmt in ALU_FORMATS: alu_insts.append(create_alu_inst(op_enum, builder)) + elif fmt in MEM_FORMATS: mem_insts.append(create_mem_inst(op_enum, builder)) + return alu_insts, mem_insts, skipped + +def exec_insts(insts:list): + k = Kernel(arch) + # ** prologue for global memory + k.emit(s_load_b64(sdata=s[S_BUF_PTR[0]:S_BUF_PTR[1]], sbase=s[S_KERNARG_PTR[0]:S_KERNARG_PTR[1]], soffset=NULL)) + k.waitcnt(lgkm=0) + k.emit(v_mov_b32_e32(v[V_VADDR[0]], 0)) + k.emit(v_mov_b32_e32(v[V_VADDR[1]], 0)) + # ** emit + for inst in insts: k.emit(inst) + k.emit(s_endpgm()) + # ** run + NUM_THREADS, NUM_GRIDS, BUF_SIZE = 32, 1, 1024*1024 + def fxn(A:UOp, B:UOp, C:UOp) -> UOp: + lidx, gidx = UOp.special(NUM_THREADS, "lidx0"), UOp.special(NUM_GRIDS, "gidx0") + lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=BUF_SIZE, addrspace=AddrSpace.LOCAL), (), "lds") + sink = UOp.sink(A.base, B.base, C.base, lds, lidx, gidx, arg=KernelInfo(name="discover_ops")) + return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple(UOp(Ops.INS, arg=x) for x in k.finalize())))) + A = Tensor.empty(BUF_SIZE, dtype=dtypes.uint8) + B = Tensor.empty(1, dtype=dtypes.uint8) + C = Tensor.empty(1, dtype=dtypes.uint8) + Tensor.custom_kernel(A, B, C, fxn=fxn)[0].realize() + +if __name__ == "__main__": + import sys + arch = Device[Device.DEFAULT].renderer.arch + if arch.startswith("gfx12"): + from tinygrad.runtime.autogen.amd.rdna4.ins import * + import tinygrad.runtime.autogen.amd.rdna4.ins as all_insts + elif arch.startswith("gfx11"): + from tinygrad.runtime.autogen.amd.rdna3.ins import * + import tinygrad.runtime.autogen.amd.rdna3.ins as all_insts + # these don"t exist in RDNA3, only RDNA3.5 and above + SKIP.update(["S_FMAAK_F32", "S_FMAMK_F32"]) + else: + print(f"{arch} not supported yet") + sys.exit(0) + alu_insts, mem_insts, skipped = collect_instructions() + print(f"collected {len(alu_insts)} ALU + {len(mem_insts)} memory instructions ({len(skipped)} skipped)") + exec_insts(mem_insts+alu_insts) diff --git a/extra/sqtt/examples/generate_examples.py b/extra/sqtt/examples/generate_examples.py index 48453764d0..0e815f928f 100644 --- a/extra/sqtt/examples/generate_examples.py +++ b/extra/sqtt/examples/generate_examples.py @@ -9,6 +9,7 @@ EXAMPLES = [ "test/backend/test_custom_kernel.py TestCustomKernel.test_empty", "test/test_tiny.py TestTiny.test_plus", "test/test_tiny.py TestTiny.test_gemm", + "extra/sqtt/examples/discover_ops.py" ] if __name__ == "__main__": diff --git a/extra/sqtt/examples/gfx1100/profile_py_run_0.pkl b/extra/sqtt/examples/gfx1100/profile_py_run_0.pkl new file mode 100644 index 0000000000..804ed3bb43 Binary files /dev/null and b/extra/sqtt/examples/gfx1100/profile_py_run_0.pkl differ diff --git a/extra/sqtt/examples/gfx1100/profile_py_run_1.pkl b/extra/sqtt/examples/gfx1100/profile_py_run_1.pkl new file mode 100644 index 0000000000..e67773b609 Binary files /dev/null and b/extra/sqtt/examples/gfx1100/profile_py_run_1.pkl differ diff --git a/extra/sqtt/examples/gfx1200/profile_py_run_0.pkl b/extra/sqtt/examples/gfx1200/profile_py_run_0.pkl new file mode 100644 index 0000000000..d098e4d09f Binary files /dev/null and b/extra/sqtt/examples/gfx1200/profile_py_run_0.pkl differ diff --git a/extra/sqtt/examples/gfx1200/profile_py_run_1.pkl b/extra/sqtt/examples/gfx1200/profile_py_run_1.pkl new file mode 100644 index 0000000000..d28e3537e7 Binary files /dev/null and b/extra/sqtt/examples/gfx1200/profile_py_run_1.pkl differ diff --git a/test/amd/test_sqtt_tables.py b/test/amd/test_sqtt_tables.py index 8ced6b4ead..6f44095743 100644 --- a/test/amd/test_sqtt_tables.py +++ b/test/amd/test_sqtt_tables.py @@ -134,7 +134,10 @@ class TestSQTTMatchesBinary(unittest.TestCase): def _test_bit_counts(self, layout: int): if not (tables := extract_bit_tables()): self.skipTest("rocprof-trace-decoder not installed") from tinygrad.renderer.amd.sqtt import PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4 + # rocprof's bit table says L4 type 7 (TS_DELTA_S8_W3) is 72 bits, but the actual decoder uses 64 bits + skip = {(4, 7)} for type_id, pkt_cls in {3: PACKET_TYPES_RDNA3, 4: PACKET_TYPES_RDNA4}[layout].items(): + if (layout, type_id) in skip: continue with self.subTest(packet=pkt_cls.__name__): self.assertEqual(pkt_cls._size_nibbles * 4, tables[layout - 2][type_id]) # type: ignore[attr-defined] diff --git a/test/amd/test_sqttmap.py b/test/amd/test_sqttmap.py index e9c5d42753..162d31e653 100644 --- a/test/amd/test_sqttmap.py +++ b/test/amd/test_sqttmap.py @@ -2,7 +2,7 @@ import unittest, pickle from typing import Iterator from pathlib import Path -from tinygrad.helpers import DEBUG +from tinygrad.helpers import DEBUG, OSX from tinygrad.renderer.amd.sqtt import print_packets, map_insts from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm from test.amd.disasm import disasm @@ -10,7 +10,7 @@ from test.amd.disasm import disasm import tinygrad EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples" -def rocprof_inst_traces_match(sqtt, prg, target): +def rocprof_inst_traces_match(sqtt, prg, target, pass_rocprof_err=False): from tinygrad.viz.serve import amd_decode from extra.sqtt.roc import decode as roc_decode, InstExec addr_table = amd_decode(prg.lib, target) @@ -30,7 +30,7 @@ def rocprof_inst_traces_match(sqtt, prg, target): rocprof_inst = next(rwaves_iter[info.wave][0]) ref_pc = rocprof_inst.pc-prg.base # always check pc matches - assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}" + assert ref_pc == info.pc or pass_rocprof_err, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}" # special handling for s_endpgm, it marks the wave completion. if info.inst == s_endpgm(): completed_wave = list(rwaves_iter[info.wave].pop(0)) @@ -67,7 +67,9 @@ class TestSQTTMapBase(unittest.TestCase): if not event.itrace: continue if event.kern not in kern_events: continue with self.subTest(example=name, kern=event.kern): - passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target) + # rocprof OSX has a bug for sopk decoding, linux rocprof works + pass_rocprof_err = OSX and target == "gfx1200" and name.startswith("profile_py") + passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target, pass_rocprof_err) if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units") class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100" diff --git a/tinygrad/renderer/amd/sqtt.py b/tinygrad/renderer/amd/sqtt.py index a2a2f9c599..6f8c1226e9 100644 --- a/tinygrad/renderer/amd/sqtt.py +++ b/tinygrad/renderer/amd/sqtt.py @@ -46,6 +46,7 @@ class InstOp(Enum): SMEM = 0x1 JUMP = 0x3 # branch taken JUMP_NO = 0x4 # branch not taken + CALL = 0x5 # s_call_b64 MESSAGE = 0x9 VALU_TRANS = 0xb # transcendental: exp, log, rcp, sqrt, sin, cos VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr @@ -72,8 +73,10 @@ class InstOp(Enum): # LDS ops on traced SIMD LDS_LOAD = 0x29 + LDS_ATOMIC = 0x2a # ds_append, ds_consume, ds_store_addtid_b32 LDS_STORE = 0x2b LDS_STORE_64 = 0x2c + LDS_STORE_96 = 0x2d LDS_STORE_128 = 0x2e # Memory ops on other SIMD (0x5x range) @@ -99,17 +102,27 @@ class InstOp(Enum): class InstOpRDNA4(Enum): """SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3.""" - # TODO: we need to do discovery of all of these from instructions SALU = 0x0 JUMP = 0x1 NEXT = 0x2 MESSAGE = 0x4 + VALU_TRANS = 0x5 VALU_64 = 0x6 + VALU_MAD64 = 0x7 + VINTERP = 0x9 VALU_WMMA = 0x46 VMEM = 0x10 VMEM_128 = 0x11 VMEM_STORE = 0x12 - VMEM_STORE_128 = 0x14 + VMEM_STORE_G96 = 0x13 # global_store_[b96,b128] + LDS_LOAD = 0x14 + LDS_STORE = 0x15 + LDS_STORE_64 = 0x16 + LDS_STORE_128 = 0x17 + VALU_F64 = 0x49 + SALU_TRANS = 0x4c # transcendental with sgpr src/dst + SALU_MUL = 0x4d # s_[mul,mulhi,mulk] + SALU_MUL64 = 0x4e OTHER_VMEM = 0x5e OTHER_VMEM_STORE = 0x60 @@ -147,11 +160,6 @@ class TS_DELTA_S8_W3(PacketType): delta = bits[10:8] _padding = bits[63:11] -class TS_DELTA_S8_W3_RDNA4(PacketType): # Layout 4: 64->72 bits - encoding = bits[6:0] == 0b0100001 - delta = bits[10:8] - _padding = bits[71:11] - class TS_DELTA_S5_W3(PacketType): encoding = bits[4:0] == 0b00110 delta = bits[7:5] @@ -363,7 +371,7 @@ PACKET_TYPES_RDNA3: dict[int, type[PacketType]] = { } PACKET_TYPES_RDNA4: dict[int, type[PacketType]] = { **PACKET_TYPES_RDNA3, - 7: TS_DELTA_S8_W3_RDNA4, 9: WAVESTART_RDNA4, 10: TS_DELTA_S5_W2_RDNA4, 11: WAVEALLOC_RDNA4, + 9: WAVESTART_RDNA4, 10: TS_DELTA_S5_W2_RDNA4, 11: WAVEALLOC_RDNA4, 12: TS_DELTA_S5_W3_RDNA4, 13: PERF_RDNA4, 22: TS_DELTA_OR_MARK_RDNA4, 24: INST_RDNA4, }