diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py old mode 100644 new mode 100755 index 225456b073..7085d5d621 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -1,11 +1,9 @@ -import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools, threading +#!/usr/bin/env python3 +import ctypes, pathlib, argparse, pickle, dataclasses, threading from typing import Generator from tinygrad.helpers import temp, unwrap, DEBUG -from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent -from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent -from tinygrad.runtime.autogen import llvm, rocprof -from tinygrad.runtime.support.elf import elf_loader -from tinygrad.viz.serve import llvm_disasm +from tinygrad.runtime.ops_amd import ProfileSQTTEvent +from tinygrad.runtime.autogen import rocprof @dataclasses.dataclass(frozen=True) class InstExec: @@ -117,26 +115,52 @@ def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[st def worker(): try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None) - except AttributeError as e: raise RuntimeError("Failed to find rocprof-trace-decoder. Run sudo ./extra/sqtt/install_sqtt_decoder.py to install") from e + except AttributeError as e: + raise RuntimeError("Failed to find rocprof-trace-decoder. Run sudo ./extra/sqtt/install_sqtt_decoder.py to install") from e (t:=threading.Thread(target=worker, daemon=True)).start() t.join() return ROCParseCtx -def print_pmc(events:list[ProfilePMCEvent]) -> None: - from tinygrad.viz.serve import unpack_pmc +def print_data(data:dict) -> None: from tabulate import tabulate - for e in events: - print("**", e.kern) - data = unpack_pmc(e) - print(tabulate([r[:-1] for r in data["rows"]], headers=data["cols"], tablefmt="github")) + # plaintext + if "src" in data: print(data["src"]) + # table format + elif "cols" in data: + print(tabulate([r[:len(data["cols"])] for r in data["rows"]], headers=data["cols"], tablefmt="github")) + +def main() -> None: + import tinygrad.viz.serve as viz + viz.ctxs = [] -if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True))) + parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)', + default=pathlib.Path(temp("profile.pkl", append_user=True))) + parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Kernel to focus on (optional name, default: all kernels)') + parser.add_argument('-n', type=int, default=3, metavar="NUM", help='Max traces to print (optional number, default: 3 traces)') args = parser.parse_args() with args.profile.open("rb") as f: profile = pickle.load(f) - #rctx = decode(profile, disasm) - #print('SQTT:', rctx.inst_execs.keys()) - print_pmc([ev for ev in profile if isinstance(ev, ProfilePMCEvent)]) + viz.get_profile(profile) + + # List all kernels + if args.kernel is None: + for c in viz.ctxs: + print(c["name"]) + for s in c["steps"]: print(" "+s["name"]) + return None + + # Find kernel trace + trace = next((c for c in viz.ctxs if c["name"] == f"Exec {args.kernel}"), None) + if not trace: raise RuntimeError(f"no matching trace for {args.kernel}") + n = 0 + for s in trace["steps"]: + print(s["name"]) + data = viz.get_render(s["query"]) + print_data(data) + n += 1 + if n > args.n: break + +if __name__ == "__main__": + main() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 7514bd34ae..e997985ac3 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -197,7 +197,7 @@ AMD_CC, CPU_CC, NV_CC, CUDA_CC = ContextVar("AMD_CC", ""), ContextVar("CPU_CC", QCOM_CC = ContextVar("QCOM_CC", "") # VIZ implies PROFILE, but you can run PROFILE without VIZ VIZ = ContextVar("VIZ", 0) -PROFILE = ContextVar("PROFILE", VIZ.value) +PROFILE = ContextVar("PROFILE", abs(VIZ.value)) SPEC = ContextVar("SPEC", 1) # TODO: disable by default due to speed IGNORE_OOB = ContextVar("IGNORE_OOB", 1) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 5e10812348..e23b8e5943 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1176,7 +1176,7 @@ if TRACK_MATCH_STATS or PROFILE: print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f) if VIZ > 0: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True)) - if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value): + if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value and VIZ.value>=0): ret = [0,0,0.0,0.0] for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]): loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index f0c937dbf2..104c019a46 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -266,7 +266,7 @@ def load_counters(profile:list[ProfileEvent]) -> None: # run our decoder on startup, we don't use this since it only works on gfx11 from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets for e in sqtt: parse_sqtt_print_packets(e.blob) - ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps}) + ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps}) # ** SQTT OCC only unpacks wave start, end time and SIMD location @@ -424,7 +424,9 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict: # ** Main render function to get the complete details about a trace event -def get_render(i:int, j:int, fmt:str) -> dict: +def get_render(query:str) -> dict: + url = urlparse(query) + i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/") data = ctxs[i]["steps"][j]["data"] if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"} if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"} @@ -504,16 +506,17 @@ class Handler(HTTPRequestHandler): if url.path.endswith(".js"): content_type = "application/javascript" if url.path.endswith(".css"): content_type = "text/css" except FileNotFoundError: status_code = 404 - elif (query:=parse_qs(url.query)): - render_src = get_render(get_int(query, "ctx"), get_int(query, "step"), url.path.lstrip("/")) - if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"] - else: ret, content_type = json.dumps(render_src).encode(), "application/json" - if content_type == "text/event-stream": return self.stream_json(render_src["value"]) + elif url.path == "/ctxs": lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs] ret, content_type = json.dumps(lst).encode(), "application/json" elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream" - else: status_code = 404 + else: + if not (render_src:=get_render(self.path)): status_code = 404 + else: + if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"] + else: ret, content_type = json.dumps(render_src).encode(), "application/json" + if content_type == "text/event-stream": return self.stream_json(render_src["value"]) return self.send_data(ret, content_type, status_code)