mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: start displaying pma (#14848)
* viz: start displaying pma * s * work * colors * cleaner * max packets * fine * work * pma * diff cleanup
This commit is contained in:
@@ -192,7 +192,9 @@ const waveColor = (op) => {
|
||||
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
|
||||
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
|
||||
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),
|
||||
WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
|
||||
GPC:new Map([["NONE","#1a7a2e"],["MEMORY_DEPENDENCY","#8b1a00"],["EXEC_DEPENDENCY","#006b6b"],["INST_FETCH","#7a7a00"],["SYNC","#6b006b"],
|
||||
["PIPE_BUSY","#7a4a00"],["MEMORY_THROTTLE","#5c0000"],["CONSTANT_MEMORY","#1a3d7a"],["NOT_SELECTED","#2e2e3a"],["OTHER","#4a4a55"],
|
||||
["SLEEPING","#1a1a2a"],["DEFAULT","#3a3a45"]]), WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
|
||||
const cycleColors = (lst, i) => lst[i%lst.length];
|
||||
|
||||
const rescaleTrack = (source, tid, k) => {
|
||||
@@ -826,7 +828,7 @@ async function main() {
|
||||
}
|
||||
// timeline with cycles on the x axis
|
||||
if (ret instanceof ArrayBuffer) {
|
||||
opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), colorByName:step.name.includes("PKTS")};
|
||||
opts = {heightScale:0.5, hideLabels:true, levelKey:step.name.includes("PKTS") ? (e) => parseInt(e.name.split(" ")[1].split(":")[1]) : null, colorByName:ckey.includes("pkts")};
|
||||
return renderProfiler(ckey, "clk", opts);
|
||||
}
|
||||
metadata.innerHTML = "";
|
||||
|
||||
@@ -404,6 +404,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
||||
if (d:=ev.device.split(":")[0]) == "AMD":
|
||||
device_decoders[d] = load_counters
|
||||
amdgpu_targets[d] = f"gfx{unwrap(ev.props)['gfx_target_version']//1000}"
|
||||
if d == "NV": device_decoders[d] = load_pma_counters
|
||||
# load device specific counters
|
||||
for fxn in device_decoders.values(): fxn(profile)
|
||||
# map events per device
|
||||
@@ -431,6 +432,35 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
||||
index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":rel_ts(e.ts, start_ts), **e.arg} for e in markers]}).encode()
|
||||
return struct.pack("<IQII", rel_ts(unwrap(end_ts), start_ts), max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
|
||||
|
||||
# ** PMA counters
|
||||
|
||||
def load_pma_counters(profile:list) -> None:
|
||||
steps:list[dict] = []
|
||||
sm_version = {e.device:e.props.get("sm_version", 0x800) for e in profile if isinstance(e, ProfileDeviceEvent) and e.props is not None}
|
||||
run_number:dict[str, int] = {}
|
||||
for e in profile:
|
||||
if type(e).__name__ == "ProfilePMAEvent":
|
||||
run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1
|
||||
steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(ctxs), len(steps)),
|
||||
data=(e.blob, sm_version[e.device])))
|
||||
if steps: ctxs.append({"name":"All Counters", "steps":steps})
|
||||
|
||||
def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]:
|
||||
from extra.nv_pma.decode import decode, decode_tpc_id
|
||||
ret:list[ProfileEvent] = []
|
||||
rows:dict[str, None] = {}
|
||||
tpc_count:dict[int, int] = {}
|
||||
# assume every sample is 32 cycles
|
||||
cycles_per_sample = 32
|
||||
for s, tpc_id in decode(blob, sm_version):
|
||||
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
|
||||
gpc, tpc, sm = decode_tpc_id(tpc_id)
|
||||
tpc_count[tpc_id] = (n:=tpc_count.get(tpc_id,0)) + 1
|
||||
rows.setdefault(row:=f"GPC:{gpc} TPC:{tpc} SM:{sm} WAVE:{s.wave_id}")
|
||||
ret.append(ProfileRangeEvent(row, TracingKey(s.stall_reason.name, ret=f"pc=0x{s.pc_offset:06x} active={s.active}"),
|
||||
Decimal(n*cycles_per_sample), Decimal((n+1)*cycles_per_sample)))
|
||||
return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret
|
||||
|
||||
# ** Assembly static analyzers
|
||||
|
||||
def get_stdout(f: Callable) -> str:
|
||||
@@ -585,6 +615,12 @@ def get_render(query:str) -> dict:
|
||||
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
|
||||
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
|
||||
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)}
|
||||
if fmt == "prg-pma-pkts":
|
||||
ret = {}
|
||||
with soft_err(lambda err:ret.update(err)):
|
||||
if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
else: ret = {"src":"No PMA samples found."}
|
||||
return ret
|
||||
return data
|
||||
|
||||
# ** HTTP server
|
||||
|
||||
Reference in New Issue
Block a user