command line interface for sqtt viz (#13891)

* command line interface for sqtt viz

* cleanup

* api surface area

* this confuses the llms

* document
This commit is contained in:
qazal
2025-12-30 12:33:21 +09:00
committed by GitHub
parent ab58926b00
commit d7e1f26e3d
4 changed files with 55 additions and 28 deletions

60
extra/sqtt/roc.py Normal file → Executable file
View File

@@ -1,11 +1,9 @@
import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools, threading #!/usr/bin/env python3
import ctypes, pathlib, argparse, pickle, dataclasses, threading
from typing import Generator from typing import Generator
from tinygrad.helpers import temp, unwrap, DEBUG from tinygrad.helpers import temp, unwrap, DEBUG
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent from tinygrad.runtime.ops_amd import ProfileSQTTEvent
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent from tinygrad.runtime.autogen import rocprof
from tinygrad.runtime.autogen import llvm, rocprof
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.viz.serve import llvm_disasm
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class InstExec: class InstExec:
@@ -117,26 +115,52 @@ def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[st
def worker(): def worker():
try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None) try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None)
except AttributeError as e: raise RuntimeError("Failed to find rocprof-trace-decoder. Run sudo ./extra/sqtt/install_sqtt_decoder.py to install") from e except AttributeError as e:
raise RuntimeError("Failed to find rocprof-trace-decoder. Run sudo ./extra/sqtt/install_sqtt_decoder.py to install") from e
(t:=threading.Thread(target=worker, daemon=True)).start() (t:=threading.Thread(target=worker, daemon=True)).start()
t.join() t.join()
return ROCParseCtx return ROCParseCtx
def print_pmc(events:list[ProfilePMCEvent]) -> None: def print_data(data:dict) -> None:
from tinygrad.viz.serve import unpack_pmc
from tabulate import tabulate from tabulate import tabulate
for e in events: # plaintext
print("**", e.kern) if "src" in data: print(data["src"])
data = unpack_pmc(e) # table format
print(tabulate([r[:-1] for r in data["rows"]], headers=data["cols"], tablefmt="github")) elif "cols" in data:
print(tabulate([r[:len(data["cols"])] for r in data["rows"]], headers=data["cols"], tablefmt="github"))
def main() -> None:
import tinygrad.viz.serve as viz
viz.ctxs = []
if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True))) parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)',
default=pathlib.Path(temp("profile.pkl", append_user=True)))
parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Kernel to focus on (optional name, default: all kernels)')
parser.add_argument('-n', type=int, default=3, metavar="NUM", help='Max traces to print (optional number, default: 3 traces)')
args = parser.parse_args() args = parser.parse_args()
with args.profile.open("rb") as f: profile = pickle.load(f) with args.profile.open("rb") as f: profile = pickle.load(f)
#rctx = decode(profile, disasm)
#print('SQTT:', rctx.inst_execs.keys())
print_pmc([ev for ev in profile if isinstance(ev, ProfilePMCEvent)]) viz.get_profile(profile)
# List all kernels
if args.kernel is None:
for c in viz.ctxs:
print(c["name"])
for s in c["steps"]: print(" "+s["name"])
return None
# Find kernel trace
trace = next((c for c in viz.ctxs if c["name"] == f"Exec {args.kernel}"), None)
if not trace: raise RuntimeError(f"no matching trace for {args.kernel}")
n = 0
for s in trace["steps"]:
print(s["name"])
data = viz.get_render(s["query"])
print_data(data)
n += 1
if n > args.n: break
if __name__ == "__main__":
main()

View File

@@ -197,7 +197,7 @@ AMD_CC, CPU_CC, NV_CC, CUDA_CC = ContextVar("AMD_CC", ""), ContextVar("CPU_CC",
QCOM_CC = ContextVar("QCOM_CC", "") QCOM_CC = ContextVar("QCOM_CC", "")
# VIZ implies PROFILE, but you can run PROFILE without VIZ # VIZ implies PROFILE, but you can run PROFILE without VIZ
VIZ = ContextVar("VIZ", 0) VIZ = ContextVar("VIZ", 0)
PROFILE = ContextVar("PROFILE", VIZ.value) PROFILE = ContextVar("PROFILE", abs(VIZ.value))
SPEC = ContextVar("SPEC", 1) SPEC = ContextVar("SPEC", 1)
# TODO: disable by default due to speed # TODO: disable by default due to speed
IGNORE_OOB = ContextVar("IGNORE_OOB", 1) IGNORE_OOB = ContextVar("IGNORE_OOB", 1)

View File

@@ -1176,7 +1176,7 @@ if TRACK_MATCH_STATS or PROFILE:
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f) pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
if VIZ > 0: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True)) if VIZ > 0: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value): if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value and VIZ.value>=0):
ret = [0,0,0.0,0.0] ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]): for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"

View File

@@ -266,7 +266,7 @@ def load_counters(profile:list[ProfileEvent]) -> None:
# run our decoder on startup, we don't use this since it only works on gfx11 # run our decoder on startup, we don't use this since it only works on gfx11
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
for e in sqtt: parse_sqtt_print_packets(e.blob) for e in sqtt: parse_sqtt_print_packets(e.blob)
ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps}) ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
# ** SQTT OCC only unpacks wave start, end time and SIMD location # ** SQTT OCC only unpacks wave start, end time and SIMD location
@@ -424,7 +424,9 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
# ** Main render function to get the complete details about a trace event # ** Main render function to get the complete details about a trace event
def get_render(i:int, j:int, fmt:str) -> dict: def get_render(query:str) -> dict:
url = urlparse(query)
i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/")
data = ctxs[i]["steps"][j]["data"] data = ctxs[i]["steps"][j]["data"]
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"} if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"} if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
@@ -504,16 +506,17 @@ class Handler(HTTPRequestHandler):
if url.path.endswith(".js"): content_type = "application/javascript" if url.path.endswith(".js"): content_type = "application/javascript"
if url.path.endswith(".css"): content_type = "text/css" if url.path.endswith(".css"): content_type = "text/css"
except FileNotFoundError: status_code = 404 except FileNotFoundError: status_code = 404
elif (query:=parse_qs(url.query)):
render_src = get_render(get_int(query, "ctx"), get_int(query, "step"), url.path.lstrip("/"))
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
if content_type == "text/event-stream": return self.stream_json(render_src["value"])
elif url.path == "/ctxs": elif url.path == "/ctxs":
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs] lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs]
ret, content_type = json.dumps(lst).encode(), "application/json" ret, content_type = json.dumps(lst).encode(), "application/json"
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream" elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
else: status_code = 404 else:
if not (render_src:=get_render(self.path)): status_code = 404
else:
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
if content_type == "text/event-stream": return self.stream_json(render_src["value"])
return self.send_data(ret, content_type, status_code) return self.send_data(ret, content_type, status_code)