mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
viz: faster startup time with SQTT=1 (#13337)
* roc.py cleanups * direct append * viz index cleanup * simd row details * add kernel arg * late instructions decode * more instruction decode to sep server request * 200ms startup, 6 second to waves timeline * sort units * creating new http paths is easy now * instructions unpacker * min diff, use hyphens * summary table
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools, threading
|
||||
from typing import Generator
|
||||
from tinygrad.helpers import temp, unwrap, DEBUG
|
||||
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
||||
@@ -31,7 +32,7 @@ def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class InstExec:
|
||||
typ:str
|
||||
inst:str
|
||||
pc:int
|
||||
stall:int
|
||||
dur:int
|
||||
time:int
|
||||
@@ -44,7 +45,13 @@ class WaveExec:
|
||||
se:int
|
||||
begin_time:int
|
||||
end_time:int
|
||||
insts:list[InstExec]
|
||||
insts:bytearray
|
||||
def unpack_insts(self) -> Generator[InstExec, None, None]:
|
||||
sz = ctypes.sizeof(struct:=rocprof.rocprofiler_thread_trace_decoder_inst_t)
|
||||
insts_array = (struct*(len(self.insts)//sz)).from_buffer(self.insts)
|
||||
for inst in insts_array:
|
||||
inst_typ = rocprof.enum_rocprofiler_thread_trace_decoder_inst_category_t.get(inst.category)
|
||||
yield InstExec(inst_typ, inst.pc.address, inst.stall, inst.duration, inst.time)
|
||||
|
||||
class _ROCParseCtx:
|
||||
def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]):
|
||||
@@ -70,18 +77,12 @@ class _ROCParseCtx:
|
||||
def on_wave_ev(self, ev:rocprof.rocprofiler_thread_trace_decoder_wave_t):
|
||||
if DEBUG >= 5: print("WAVE", ev.wave_id, self.active_se, ev.cu, ev.simd, ev.contexts, ev.begin_time, ev.end_time)
|
||||
|
||||
inst_execs:list[InstExec] = []
|
||||
disasm = self.disasms[unwrap(self.active_kern)]
|
||||
for j in range(ev.instructions_size):
|
||||
inst_ev = ev.instructions_array[j]
|
||||
inst_typ = rocprof.enum_rocprofiler_thread_trace_decoder_inst_category_t.get(inst_ev.category)
|
||||
inst_disasm = disasm[unwrap(inst_ev.pc.address)][0]
|
||||
inst_execs.append(InstExec(inst_typ, inst_disasm, inst_ev.stall, inst_ev.duration, inst_ev.time))
|
||||
if DEBUG >= 8: print(inst_execs[-1])
|
||||
insts_blob = bytearray(sz:=ev.instructions_size * ctypes.sizeof(rocprof.rocprofiler_thread_trace_decoder_inst_t))
|
||||
ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz)
|
||||
|
||||
if ev.instructions_size > 0:
|
||||
self.inst_execs.setdefault(unwrap(self.active_kern), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time,
|
||||
ev.end_time, inst_execs))
|
||||
ev.end_time, insts_blob))
|
||||
|
||||
def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
|
||||
dev_events:dict[str, ProfileDeviceEvent] = {}
|
||||
|
||||
@@ -225,13 +225,7 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
|
||||
steps:list[dict] = []
|
||||
for name,waves in rctx.inst_execs.items():
|
||||
# Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction.
|
||||
# The idle time can be caused by:
|
||||
# * Arbiter loss
|
||||
# * Source or destination register dependency
|
||||
# * Instruction cache miss
|
||||
# Stall: The total number of cycles the hardware pipe couldn't issue an instruction.
|
||||
# Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+.
|
||||
disasm = rctx.disasms[name]
|
||||
units:dict[str, int] = {}
|
||||
events:list[ProfileEvent] = []
|
||||
wave_execs:dict[str, dict] = {}
|
||||
@@ -239,19 +233,13 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
if (row:=f"SE:{w.se} CU:{w.cu} SIMD:{w.simd} WAVE:{w.wave_id}") not in units: units[row] = 0
|
||||
units[row] += 1
|
||||
events.append(ProfileRangeEvent(row, f"N:{units[row]}", Decimal(w.begin_time), Decimal(w.end_time)))
|
||||
rows, prev_instr = [], w.begin_time
|
||||
for i,e in enumerate(w.insts):
|
||||
rows.append((e.inst, e.time, max(0, e.time-prev_instr), e.dur, e.stall, str(e.typ).split("_")[-1]))
|
||||
prev_instr = max(prev_instr, e.time + e.dur)
|
||||
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
|
||||
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":units[row]}]
|
||||
wave_execs[f"{row} N:{units[row]}"] = {"rows":rows, "cols":["Instruction", "Clk", "Idle", "Duration", "Stall", "Type"], "summary":summary}
|
||||
wave_execs[f"{row} N:{units[row]}"] = {"wave":w, "disasm":disasm, "run_number":units[row]}
|
||||
# gather and sort all wave execs of this kernel
|
||||
events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+events
|
||||
kernel = trace.keys[r].ret if (r:=ref_map.get(name)) else None
|
||||
steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)),
|
||||
{"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=1))
|
||||
for k in sorted(wave_execs, key=row_tuple): steps.append(create_step(k, ("/counters", len(ctxs), len(steps)), wave_execs[k], depth=2))
|
||||
for k in sorted(wave_execs, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_execs[k], depth=2))
|
||||
ctxs.append({"name":"Counters", "steps":steps})
|
||||
|
||||
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=None) -> bytes|None:
|
||||
@@ -330,6 +318,24 @@ def get_render(i:int, j:int, fmt:str) -> dict:
|
||||
return get_llvm_mca(disasm_str, ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode(),
|
||||
ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode())
|
||||
return {"src":disasm_str, "lang":"x86asm"}
|
||||
if fmt == "sqtt-insts":
|
||||
columns = ["Instruction", "Clk", "Idle", "Duration", "Stall", "Type"]
|
||||
# Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction.
|
||||
# The idle time can be caused by:
|
||||
# * Arbiter loss
|
||||
# * Source or destination register dependency
|
||||
# * Instruction cache miss
|
||||
# Stall: The total number of cycles the hardware pipe couldn't issue an instruction.
|
||||
# Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+.
|
||||
prev_instr = (w:=data["wave"]).begin_time
|
||||
pc_to_inst = data["disasm"]
|
||||
rows:list[tuple] = []
|
||||
for e in w.unpack_insts():
|
||||
rows.append((pc_to_inst[e.pc][0], e.time, max(0, e.time-prev_instr), e.dur, e.stall, str(e.typ).split("_")[-1]))
|
||||
prev_instr = max(prev_instr, e.time + e.dur)
|
||||
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
|
||||
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
|
||||
return {"rows":rows, "cols":columns, "summary":summary}
|
||||
return data
|
||||
|
||||
# ** HTTP server
|
||||
|
||||
Reference in New Issue
Block a user