viz: start displaying pma (#14848)

* viz: start displaying pma

* s

* work

* colors

* cleaner

* max packets

* fine

* work

* pma

* diff cleanup
This commit is contained in:
qazal
2026-02-18 13:22:32 +08:00
committed by GitHub
parent d5636fba90
commit a3d516c4b5
2 changed files with 40 additions and 2 deletions

View File

@@ -192,7 +192,9 @@ const waveColor = (op) => {
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),
WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
GPC:new Map([["NONE","#1a7a2e"],["MEMORY_DEPENDENCY","#8b1a00"],["EXEC_DEPENDENCY","#006b6b"],["INST_FETCH","#7a7a00"],["SYNC","#6b006b"],
["PIPE_BUSY","#7a4a00"],["MEMORY_THROTTLE","#5c0000"],["CONSTANT_MEMORY","#1a3d7a"],["NOT_SELECTED","#2e2e3a"],["OTHER","#4a4a55"],
["SLEEPING","#1a1a2a"],["DEFAULT","#3a3a45"]]), WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
const cycleColors = (lst, i) => lst[i%lst.length];
const rescaleTrack = (source, tid, k) => {
@@ -826,7 +828,7 @@ async function main() {
}
// timeline with cycles on the x axis
if (ret instanceof ArrayBuffer) {
opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), colorByName:step.name.includes("PKTS")};
opts = {heightScale:0.5, hideLabels:true, levelKey:step.name.includes("PKTS") ? (e) => parseInt(e.name.split(" ")[1].split(":")[1]) : null, colorByName:ckey.includes("pkts")};
return renderProfiler(ckey, "clk", opts);
}
metadata.innerHTML = "";

View File

@@ -404,6 +404,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
if (d:=ev.device.split(":")[0]) == "AMD":
device_decoders[d] = load_counters
amdgpu_targets[d] = f"gfx{unwrap(ev.props)['gfx_target_version']//1000}"
if d == "NV": device_decoders[d] = load_pma_counters
# load device specific counters
for fxn in device_decoders.values(): fxn(profile)
# map events per device
@@ -431,6 +432,35 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":rel_ts(e.ts, start_ts), **e.arg} for e in markers]}).encode()
return struct.pack("<IQII", rel_ts(unwrap(end_ts), start_ts), max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
# ** PMA counters
def load_pma_counters(profile:list) -> None:
  # Builds one viz step per "ProfilePMAEvent" found in `profile` and, when at least one exists,
  # appends them as a single "All Counters" entry to the `ctxs` list defined outside this function.
  # Each step carries data=(event blob, sm_version of the event's device).
  default_sm_version = 0x800
  # device -> sm_version, taken from device events that have props; missing "sm_version" keys use the default
  sm_version = {e.device:e.props.get("sm_version", default_sm_version) for e in profile
                if isinstance(e, ProfileDeviceEvent) and e.props is not None}
  steps:list[dict] = []
  # kernel name -> how many PMA events have been seen for it so far
  run_number:dict[str, int] = {}
  for e in profile:
    # PMA events are matched by class name
    if type(e).__name__ != "ProfilePMAEvent": continue
    run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1
    # repeated runs of the same kernel get an "n<run>" suffix so step names stay distinct
    # NOTE(review): the suffix is appended without a separator ("PMA kernn2") -- confirm this is intended
    name = f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else "")
    # a device with no usable device event falls back to the default sm_version instead of raising KeyError
    data = (e.blob, sm_version.get(e.device, default_sm_version))
    steps.append(create_step(name, ("/prg-pma-pkts", len(ctxs), len(steps)), data=data))
  if steps: ctxs.append({"name":"All Counters", "steps":steps})
def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]:
  # Decodes a PMA blob into timeline events: one "start" point event per distinct row,
  # followed by one range event per decoded sample. Rows are keyed by GPC/TPC/SM/WAVE.
  from extra.nv_pma.decode import decode, decode_tpc_id
  # assume every sample is 32 cycles
  cycles_per_sample = 32
  sample_events:list[ProfileEvent] = []
  # dict used as an insertion-ordered set of row names
  seen_rows:dict[str, None] = {}
  # tpc_id -> number of samples already emitted for it; gives each sample its slot on the time axis
  samples_per_tpc:dict[int, int] = {}
  for sample, tpc_id in decode(blob, sm_version):
    # stop decoding once the packet cap is exceeded
    if len(sample_events) > getenv("MAX_SQTT_PKTS", 50_000): break
    gpc, tpc, sm = decode_tpc_id(tpc_id)
    slot = samples_per_tpc.get(tpc_id, 0)
    samples_per_tpc[tpc_id] = slot + 1
    row_name = f"GPC:{gpc} TPC:{tpc} SM:{sm} WAVE:{sample.wave_id}"
    if row_name not in seen_rows: seen_rows[row_name] = None
    # the stall reason names the event; pc offset and active flag go in the key's `ret` text
    key = TracingKey(sample.stall_reason.name, ret=f"pc=0x{sample.pc_offset:06x} active={sample.active}")
    begin = Decimal(slot*cycles_per_sample)
    end = Decimal((slot+1)*cycles_per_sample)
    sample_events.append(ProfileRangeEvent(row_name, key, begin, end))
  row_markers = [ProfilePointEvent(name, "start", name, ts=Decimal(0)) for name in seen_rows]
  return row_markers+sample_events
# ** Assembly static analyzers
def get_stdout(f: Callable) -> str:
@@ -585,6 +615,12 @@ def get_render(query:str) -> dict:
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)}
if fmt == "prg-pma-pkts":
ret = {}
with soft_err(lambda err:ret.update(err)):
if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
else: ret = {"src":"No PMA samples found."}
return ret
return data
# ** HTTP server