mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
sqtt: remove discover_ops script (#15279)
This commit is contained in:
@@ -1,148 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Run all ALU and memory instructions in the ISA
|
||||
import functools, inspect
|
||||
from enum import Enum
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AddrSpace
|
||||
from tinygrad.renderer.amd.dsl import Inst, Reg, OPERANDS, SrcField, VGPRField, SGPRField, SSrcField, SBaseField, AlignedSGPRField, BitField
|
||||
from tinygrad.renderer.amd.dsl import FixedBitField, EnumBitField, s, v, NULL, VCC_LO
|
||||
from extra.gemm.amd_asm_matmul import Kernel
|
||||
|
||||
# exact opcode names that are never emitted: they mutate wave state (PC, EXEC, allocations, signals)
SKIP = {
  "S_SETPC_B64", "S_SWAPPC_B64", "S_RFE_B64", "S_GETPC_B64",
  "S_BARRIER_SIGNAL_ISFIRST", "S_GET_BARRIER_STATE",
  "S_ALLOC_VGPR", "S_SLEEP_VAR",
  "S_SENDMSG_RTN_B32", "S_SENDMSG_RTN_B64",
}

# any opcode whose name contains one of these substrings is skipped as well
# (barriers, s_waits, wrap level atomics, and ray tracing (bvh))
SKIP_SUBSTR = [
  "SAVEEXEC", "CMPX", "WREXEC", "MOVREL", "ATOMIC", "S_BUFFER_", "S_ATC_PROBE", "BARRIER", "S_WAITCNT", "BVH",
  "DS_CMPSTORE_RTN", "DS_WRAP_RTN_B32", "DS_ORDERED_COUNT", "DS_GWS", "GS_REG", "GLOBAL_LOAD_LDS", "GLOBAL_STORE_BLOCK",
]

# instruction class names that are built as ALU instructions
ALU_FORMATS = {
  "VOP1", "VOP1_LIT", "VOP1_SDST", "VOP2", "VOP2_LIT", "VOP3", "VOP3_SDST", "VOP3SD", "VOP3P", "VOP3P_MFMA", "VOP3PX2", "VOPC",
  "SOP1", "SOP1_LIT", "SOP2", "SOP2_LIT", "SOPC", "SOPC_LIT", "SOPK", "SOPK_LIT",
  "VINTERP",
}

# instruction class names that are built as memory instructions
# NOTE: intentionally not testing scratch memory ops
MEM_FORMATS = {"VGLOBAL", "GLOBAL", "SMEM", "DS"}
|
||||
|
||||
def should_skip(op:Enum) -> bool:
  """Return True when `op` is excluded, either by exact name (SKIP) or by containing a SKIP_SUBSTR entry."""
  opname = op.name
  if opname in SKIP: return True
  for pattern in SKIP_SUBSTR:
    if pattern in opname: return True
  return False
|
||||
|
||||
# ** named register assignments

# ALU operands: operand slot N starts at v[N*ALU_VGPR_STRIDE] / s[N*ALU_SGPR_STRIDE]
ALU_VGPR_STRIDE = 16  # v[0], v[16], v[32], ...
ALU_SGPR_STRIDE = 4   # s[0], s[4], s[8], ...

# memory address registers, stored as (first, last) register index pairs or a single index
S_KERNARG_PTR = (0, 1)
S_BUF_PTR = (2, 3)
V_VADDR = (0, 1)
V_DS_ADDR = 0

# memory data registers: data slot N starts at BASE + N*STRIDE
MEM_VGPR_BASE = 32    # v[32], v[48], ... for vdst/vdata/vsrc
MEM_VGPR_STRIDE = 16  # spacing between memory data vgpr slots
MEM_SGPR_BASE = 8     # s[8], s[10], ... for SMEM sdata
MEM_SGPR_STRIDE = 2   # spacing between memory data sgpr slots
|
||||
|
||||
# ** create an ALU instruction based on the operands
|
||||
|
||||
def _alu_reg_span(bank, base:int, nregs:int):
  """Return `bank[base]` for a single register, else the range `bank[base:base+nregs-1]`."""
  return bank[base:base+nregs-1] if nregs > 1 else bank[base]

def create_alu_inst(op:Enum, builder:functools.partial[Inst]) -> Inst:
  """Build one ALU instruction for `op`, giving every register operand a placeholder register.

  Fields of `builder.func` are visited in `_fields` order. FixedBitField/EnumBitField entries are left
  to the builder. Every field that receives a register consumes one operand slot; slot N maps to
  v[N*ALU_VGPR_STRIDE] / s[N*ALU_SGPR_STRIDE]. A field that receives no register but is a
  BitField is passed its declared `default`.

  Raises KeyError if `op` has no entry in OPERANDS.
  """
  inst_cls, operands, slot = builder.func, OPERANDS[op], 0
  kwargs:dict[str, Reg|int] = {}
  for name, field in inst_cls._fields:
    if isinstance(field, (FixedBitField, EnumBitField)): continue
    described = name in operands
    # NOTE(review): operand entry [1] is used as a bit width (divided by 32 to get a register count) and
    # entry [2] as a type tag whose string form contains "SREG" for scalar operands -- confirm against OPERANDS
    nregs = max(1, operands[name][1] // 32) if described else 1
    is_sreg = described and "SREG" in str(operands[name][2])
    base_v, base_s = slot * ALU_VGPR_STRIDE, slot * ALU_SGPR_STRIDE
    # the order of these checks matters: the more specific field classes are tested before SrcField
    if (name == "sdst" and isinstance(field, SGPRField)) or (is_sreg and not isinstance(field, VGPRField)): reg = VCC_LO
    elif isinstance(field, VGPRField): reg = _alu_reg_span(v, base_v, nregs)
    # scalar sources of up to two registers use VCC_LO, so only wider operands get an s[] range here
    # (the former trailing `else s[base_s]` arm could never run: nregs == 1 is already covered by nregs <= 2)
    elif isinstance(field, SSrcField): reg = VCC_LO if nregs <= 2 else s[base_s:base_s+nregs-1]
    elif isinstance(field, SGPRField): reg = _alu_reg_span(s, base_s, nregs)
    elif isinstance(field, SrcField): reg = _alu_reg_span(v, base_v, nregs)
    else: reg = None
    if reg is not None:
      kwargs[name] = reg
      slot += 1
    elif isinstance(field, BitField): kwargs[name] = field.default
  return builder(**kwargs)
|
||||
|
||||
# ** create a memory instruction with pre set address registers
|
||||
|
||||
# fixed address operands per memory instruction class name; create_mem_inst copies these verbatim
# instead of allocating a data register for the field
MEM_PRESET_REGS:dict[str, dict[str, Reg]] = {
  "VGLOBAL": {
    "saddr": s[S_BUF_PTR[0]:S_BUF_PTR[1]],
    "vaddr": v[V_VADDR[0]:V_VADDR[1]],
  },
  "GLOBAL": {
    "saddr": s[S_BUF_PTR[0]:S_BUF_PTR[1]],
    "addr": v[V_DS_ADDR],  # addr is 32-bit offset when saddr is valid SGPR
  },
  "DS": {
    "addr": v[V_DS_ADDR],
  },
  "SMEM": {
    "sbase": s[S_KERNARG_PTR[0]:S_KERNARG_PTR[1]],
    "soffset": NULL,
  },
}
|
||||
|
||||
def create_mem_inst(op:Enum, builder:functools.partial[Inst]) -> Inst:
  """Build one memory instruction for `op` using preset address registers and fresh data registers.

  Fields named in MEM_PRESET_REGS for this instruction class are copied from that table. Remaining
  VGPR fields get v[MEM_VGPR_BASE + i*MEM_VGPR_STRIDE] and remaining scalar fields get
  s[MEM_SGPR_BASE + j*MEM_SGPR_STRIDE], where i / j count the vector / scalar data fields assigned
  so far. Any other BitField is passed its declared `default`.
  """
  inst_cls = builder.func
  operands = OPERANDS.get(op, {})
  preset = MEM_PRESET_REGS.get(inst_cls.__name__, {})
  kwargs:dict[str, Reg|int] = {}
  n_vdata = n_sdata = 0
  for name, field in inst_cls._fields:
    # fixed and enum bit fields are left to the builder
    if isinstance(field, (FixedBitField, EnumBitField)): continue
    if name in preset:
      kwargs[name] = preset[name]
      continue
    # NOTE(review): operand entry [1] is treated as a bit width -- confirm against OPERANDS
    nregs = 1
    if name in operands: nregs = max(1, operands[name][1] // 32)
    if isinstance(field, VGPRField):
      first = MEM_VGPR_BASE + n_vdata * MEM_VGPR_STRIDE
      kwargs[name] = v[first] if nregs == 1 else v[first:first+nregs-1]
      n_vdata += 1
    elif isinstance(field, (SGPRField, AlignedSGPRField, SBaseField)):
      first = MEM_SGPR_BASE + n_sdata * MEM_SGPR_STRIDE
      kwargs[name] = s[first] if nregs == 1 else s[first:first+nregs-1]
      n_sdata += 1
    elif isinstance(field, BitField):
      kwargs[name] = field.default
  return builder(**kwargs)
|
||||
|
||||
# ** collect all memory and ALU instructions from the ISA autogen
|
||||
|
||||
def collect_instructions() -> tuple[list[Inst], list[Inst], list[str]]:
  """Build every testable instruction exposed by the module-level `all_insts` module.

  Returns (alu_insts, mem_insts, skipped). `skipped` holds the names of ops rejected by should_skip or
  missing from OPERANDS. Ops whose instruction class is in neither ALU_FORMATS nor MEM_FORMATS are
  dropped without being recorded. `all_insts` is only bound when the file runs as a script.
  """
  # every functools.partial with exactly one positional arg is treated as an (op enum -> builder) entry
  builders:dict[Enum, functools.partial[Inst]] = {
    member.args[0]: member for _, member in inspect.getmembers(all_insts)
    if isinstance(member, functools.partial) and len(member.args) == 1
  }
  alu_insts:list[Inst] = []
  mem_insts:list[Inst] = []
  skipped:list[str] = []
  for op_enum, builder in builders.items():
    if should_skip(op_enum) or op_enum not in OPERANDS:
      skipped.append(op_enum.name)
      continue
    fmt = builder.func.__name__
    if fmt in ALU_FORMATS:
      alu_insts.append(create_alu_inst(op_enum, builder))
    elif fmt in MEM_FORMATS:
      mem_insts.append(create_mem_inst(op_enum, builder))
  return alu_insts, mem_insts, skipped
|
||||
|
||||
def exec_insts(insts:list):
  """Assemble `insts` into one kernel named "discover_ops" and run it once through Tensor.custom_kernel.

  Depends on the module-level `arch` and on the instruction constructors (s_load_b64, v_mov_b32_e32,
  s_endpgm) that are only bound by the star import performed when the file runs as a script.
  """
  k = Kernel(arch)
  # ** prologue for global memory: load s[S_BUF_PTR] through s[S_KERNARG_PTR], then zero v[V_VADDR]
  buf_ptr = s[S_BUF_PTR[0]:S_BUF_PTR[1]]
  kernarg_ptr = s[S_KERNARG_PTR[0]:S_KERNARG_PTR[1]]
  k.emit(s_load_b64(sdata=buf_ptr, sbase=kernarg_ptr, soffset=NULL))
  k.waitcnt(lgkm=0)
  for vaddr_idx in (V_VADDR[0], V_VADDR[1]):
    k.emit(v_mov_b32_e32(v[vaddr_idx], 0))
  # ** emit the instructions under test, then terminate the program
  for inst in insts:
    k.emit(inst)
  k.emit(s_endpgm())
  # ** run
  NUM_THREADS, NUM_GRIDS, BUF_SIZE = 32, 1, 1024*1024

  def fxn(A:UOp, B:UOp, C:UOp) -> UOp:
    lidx = UOp.special(NUM_THREADS, "lidx0")
    gidx = UOp.special(NUM_GRIDS, "gidx0")
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=BUF_SIZE, addrspace=AddrSpace.LOCAL), (), "lds")
    sink = UOp.sink(A.base, B.base, C.base, lds, lidx, gidx, arg=KernelInfo(name="discover_ops"))
    device = UOp(Ops.DEVICE, arg="AMD")
    # one INS uop per item returned by the assembled kernel
    linear = UOp(Ops.LINEAR, src=tuple(UOp(Ops.INS, arg=x) for x in k.finalize()))
    return UOp(Ops.PROGRAM, src=(sink, device, linear))

  A = Tensor.empty(BUF_SIZE, dtype=dtypes.uint8)
  B = Tensor.empty(1, dtype=dtypes.uint8)
  C = Tensor.empty(1, dtype=dtypes.uint8)
  Tensor.custom_kernel(A, B, C, fxn=fxn)[0].realize()
|
||||
|
||||
if __name__ == "__main__":
  import sys
  arch = Device[Device.DEFAULT].renderer.arch
  # only gfx11 and gfx12 targets have an instruction table wired up here
  if not arch.startswith(("gfx11", "gfx12")):
    print(f"{arch} not supported yet")
    sys.exit(0)
  if arch.startswith("gfx12"):
    from tinygrad.runtime.autogen.amd.rdna4.ins import *
    import tinygrad.runtime.autogen.amd.rdna4.ins as all_insts
  else:
    from tinygrad.runtime.autogen.amd.rdna3.ins import *
    import tinygrad.runtime.autogen.amd.rdna3.ins as all_insts
    # these don't exist in RDNA3, only RDNA3.5 and above
    SKIP.update(["S_FMAAK_F32", "S_FMAMK_F32"])
  alu_insts, mem_insts, skipped = collect_instructions()
  print(f"collected {len(alu_insts)} ALU + {len(mem_insts)} memory instructions ({len(skipped)} skipped)")
  # memory instructions are emitted ahead of the ALU instructions
  exec_insts(mem_insts+alu_insts)
|
||||
@@ -9,7 +9,6 @@ EXAMPLES = {
|
||||
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
|
||||
"plus":"test/test_tiny.py TestTiny.test_plus",
|
||||
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
|
||||
"ops":"extra/sqtt/examples/discover_ops.py"
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -2,7 +2,7 @@
|
||||
import unittest, pickle
|
||||
from typing import Iterator
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import DEBUG, OSX, getenv, temp
|
||||
from tinygrad.helpers import DEBUG, getenv, temp
|
||||
from tinygrad.renderer.amd.sqtt import print_packets, map_insts
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm
|
||||
from tinygrad.viz.serve import sqtt_timeline
|
||||
@@ -11,7 +11,7 @@ from test.amd.disasm import disasm
|
||||
import tinygrad
|
||||
EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples"
|
||||
|
||||
def rocprof_inst_traces_match(sqtt, prg, target, pass_rocprof_err=False):
|
||||
def rocprof_inst_traces_match(sqtt, prg, target):
|
||||
from tinygrad.viz.serve import amd_decode
|
||||
from extra.sqtt.roc import decode as roc_decode, InstExec
|
||||
addr_table = amd_decode(prg.lib, target)
|
||||
@@ -31,7 +31,7 @@ def rocprof_inst_traces_match(sqtt, prg, target, pass_rocprof_err=False):
|
||||
rocprof_inst = next(rwaves_iter[info.wave][0])
|
||||
ref_pc = rocprof_inst.pc-prg.base
|
||||
# always check pc matches
|
||||
assert ref_pc == info.pc or pass_rocprof_err, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
|
||||
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
|
||||
# special handling for s_endpgm, it marks the wave completion.
|
||||
if info.inst == s_endpgm():
|
||||
completed_wave = list(rwaves_iter[info.wave].pop(0))
|
||||
@@ -68,9 +68,7 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
if not event.itrace: continue
|
||||
if event.kern not in kern_events: continue
|
||||
with self.subTest(example=name, kern=event.kern):
|
||||
# rocprof OSX has a bug for sopk decoding, linux rocprof works
|
||||
pass_rocprof_err = OSX and target == "gfx1200" and name.startswith("profile_ops")
|
||||
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target, pass_rocprof_err)
|
||||
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target)
|
||||
if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units")
|
||||
|
||||
def test_sqtt_timeline(self):
|
||||
|
||||
Reference in New Issue
Block a user