mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
command line interface for sqtt viz (#13891)
* command line interface for sqtt viz
* cleanup
* api surface area
* this confuses the llms
* document
This commit is contained in:
60
extra/sqtt/roc.py
Normal file → Executable file
60
extra/sqtt/roc.py
Normal file → Executable file
@@ -1,11 +1,9 @@
|
||||
import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools, threading
|
||||
#!/usr/bin/env python3
|
||||
import ctypes, pathlib, argparse, pickle, dataclasses, threading
|
||||
from typing import Generator
|
||||
from tinygrad.helpers import temp, unwrap, DEBUG
|
||||
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
||||
from tinygrad.runtime.autogen import llvm, rocprof
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.viz.serve import llvm_disasm
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||
from tinygrad.runtime.autogen import rocprof
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class InstExec:
|
||||
@@ -117,26 +115,52 @@ def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[st
|
||||
|
||||
def worker():
|
||||
try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None)
|
||||
except AttributeError as e: raise RuntimeError("Failed to find rocprof-trace-decoder. Run sudo ./extra/sqtt/install_sqtt_decoder.py to install") from e
|
||||
except AttributeError as e:
|
||||
raise RuntimeError("Failed to find rocprof-trace-decoder. Run sudo ./extra/sqtt/install_sqtt_decoder.py to install") from e
|
||||
(t:=threading.Thread(target=worker, daemon=True)).start()
|
||||
t.join()
|
||||
return ROCParseCtx
|
||||
|
||||
def print_pmc(events:list[ProfilePMCEvent]) -> None:
|
||||
from tinygrad.viz.serve import unpack_pmc
|
||||
def print_data(data:dict) -> None:
|
||||
from tabulate import tabulate
|
||||
for e in events:
|
||||
print("**", e.kern)
|
||||
data = unpack_pmc(e)
|
||||
print(tabulate([r[:-1] for r in data["rows"]], headers=data["cols"], tablefmt="github"))
|
||||
# plaintext
|
||||
if "src" in data: print(data["src"])
|
||||
# table format
|
||||
elif "cols" in data:
|
||||
print(tabulate([r[:len(data["cols"])] for r in data["rows"]], headers=data["cols"], tablefmt="github"))
|
||||
|
||||
def main() -> None:
|
||||
import tinygrad.viz.serve as viz
|
||||
viz.ctxs = []
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||
parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)',
|
||||
default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||
parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Kernel to focus on (optional name, default: all kernels)')
|
||||
parser.add_argument('-n', type=int, default=3, metavar="NUM", help='Max traces to print (optional number, default: 3 traces)')
|
||||
args = parser.parse_args()
|
||||
|
||||
with args.profile.open("rb") as f: profile = pickle.load(f)
|
||||
#rctx = decode(profile, disasm)
|
||||
#print('SQTT:', rctx.inst_execs.keys())
|
||||
|
||||
print_pmc([ev for ev in profile if isinstance(ev, ProfilePMCEvent)])
|
||||
viz.get_profile(profile)
|
||||
|
||||
# List all kernels
|
||||
if args.kernel is None:
|
||||
for c in viz.ctxs:
|
||||
print(c["name"])
|
||||
for s in c["steps"]: print(" "+s["name"])
|
||||
return None
|
||||
|
||||
# Find kernel trace
|
||||
trace = next((c for c in viz.ctxs if c["name"] == f"Exec {args.kernel}"), None)
|
||||
if not trace: raise RuntimeError(f"no matching trace for {args.kernel}")
|
||||
n = 0
|
||||
for s in trace["steps"]:
|
||||
print(s["name"])
|
||||
data = viz.get_render(s["query"])
|
||||
print_data(data)
|
||||
n += 1
|
||||
if n > args.n: break
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -197,7 +197,7 @@ AMD_CC, CPU_CC, NV_CC, CUDA_CC = ContextVar("AMD_CC", ""), ContextVar("CPU_CC",
|
||||
QCOM_CC = ContextVar("QCOM_CC", "")
|
||||
# VIZ implies PROFILE, but you can run PROFILE without VIZ
|
||||
VIZ = ContextVar("VIZ", 0)
|
||||
PROFILE = ContextVar("PROFILE", VIZ.value)
|
||||
PROFILE = ContextVar("PROFILE", abs(VIZ.value))
|
||||
SPEC = ContextVar("SPEC", 1)
|
||||
# TODO: disable by default due to speed
|
||||
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
|
||||
|
||||
@@ -1176,7 +1176,7 @@ if TRACK_MATCH_STATS or PROFILE:
|
||||
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
|
||||
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
|
||||
if VIZ > 0: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
|
||||
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
|
||||
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value and VIZ.value>=0):
|
||||
ret = [0,0,0.0,0.0]
|
||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
|
||||
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
|
||||
|
||||
@@ -266,7 +266,7 @@ def load_counters(profile:list[ProfileEvent]) -> None:
|
||||
# run our decoder on startup, we don't use this since it only works on gfx11
|
||||
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
|
||||
for e in sqtt: parse_sqtt_print_packets(e.blob)
|
||||
ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps})
|
||||
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
@@ -424,7 +424,9 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
|
||||
|
||||
# ** Main render function to get the complete details about a trace event
|
||||
|
||||
def get_render(i:int, j:int, fmt:str) -> dict:
|
||||
def get_render(query:str) -> dict:
|
||||
url = urlparse(query)
|
||||
i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/")
|
||||
data = ctxs[i]["steps"][j]["data"]
|
||||
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
|
||||
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
|
||||
@@ -504,16 +506,17 @@ class Handler(HTTPRequestHandler):
|
||||
if url.path.endswith(".js"): content_type = "application/javascript"
|
||||
if url.path.endswith(".css"): content_type = "text/css"
|
||||
except FileNotFoundError: status_code = 404
|
||||
elif (query:=parse_qs(url.query)):
|
||||
render_src = get_render(get_int(query, "ctx"), get_int(query, "step"), url.path.lstrip("/"))
|
||||
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
|
||||
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
|
||||
if content_type == "text/event-stream": return self.stream_json(render_src["value"])
|
||||
|
||||
elif url.path == "/ctxs":
|
||||
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs]
|
||||
ret, content_type = json.dumps(lst).encode(), "application/json"
|
||||
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
|
||||
else: status_code = 404
|
||||
else:
|
||||
if not (render_src:=get_render(self.path)): status_code = 404
|
||||
else:
|
||||
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
|
||||
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
|
||||
if content_type == "text/event-stream": return self.stream_json(render_src["value"])
|
||||
|
||||
return self.send_data(ret, content_type, status_code)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user