mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
command line interface for sqtt viz (#13891)
* command line interface for sqtt viz * cleanup * api surface area * this confuses the llms * document
This commit is contained in:
60
extra/sqtt/roc.py
Normal file → Executable file
60
extra/sqtt/roc.py
Normal file → Executable file
@@ -1,11 +1,9 @@
|
|||||||
import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools, threading
|
#!/usr/bin/env python3
|
||||||
|
import ctypes, pathlib, argparse, pickle, dataclasses, threading
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
from tinygrad.helpers import temp, unwrap, DEBUG
|
from tinygrad.helpers import temp, unwrap, DEBUG
|
||||||
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent
|
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
from tinygrad.runtime.autogen import rocprof
|
||||||
from tinygrad.runtime.autogen import llvm, rocprof
|
|
||||||
from tinygrad.runtime.support.elf import elf_loader
|
|
||||||
from tinygrad.viz.serve import llvm_disasm
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class InstExec:
|
class InstExec:
|
||||||
@@ -117,26 +115,52 @@ def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[st
|
|||||||
|
|
||||||
def worker():
|
def worker():
|
||||||
try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None)
|
try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None)
|
||||||
except AttributeError as e: raise RuntimeError("Failed to find rocprof-trace-decoder. Run sudo ./extra/sqtt/install_sqtt_decoder.py to install") from e
|
except AttributeError as e:
|
||||||
|
raise RuntimeError("Failed to find rocprof-trace-decoder. Run sudo ./extra/sqtt/install_sqtt_decoder.py to install") from e
|
||||||
(t:=threading.Thread(target=worker, daemon=True)).start()
|
(t:=threading.Thread(target=worker, daemon=True)).start()
|
||||||
t.join()
|
t.join()
|
||||||
return ROCParseCtx
|
return ROCParseCtx
|
||||||
|
|
||||||
def print_pmc(events:list[ProfilePMCEvent]) -> None:
|
def print_data(data:dict) -> None:
|
||||||
from tinygrad.viz.serve import unpack_pmc
|
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
for e in events:
|
# plaintext
|
||||||
print("**", e.kern)
|
if "src" in data: print(data["src"])
|
||||||
data = unpack_pmc(e)
|
# table format
|
||||||
print(tabulate([r[:-1] for r in data["rows"]], headers=data["cols"], tablefmt="github"))
|
elif "cols" in data:
|
||||||
|
print(tabulate([r[:len(data["cols"])] for r in data["rows"]], headers=data["cols"], tablefmt="github"))
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
import tinygrad.viz.serve as viz
|
||||||
|
viz.ctxs = []
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)',
|
||||||
|
default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||||
|
parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Kernel to focus on (optional name, default: all kernels)')
|
||||||
|
parser.add_argument('-n', type=int, default=3, metavar="NUM", help='Max traces to print (optional number, default: 3 traces)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with args.profile.open("rb") as f: profile = pickle.load(f)
|
with args.profile.open("rb") as f: profile = pickle.load(f)
|
||||||
#rctx = decode(profile, disasm)
|
|
||||||
#print('SQTT:', rctx.inst_execs.keys())
|
|
||||||
|
|
||||||
print_pmc([ev for ev in profile if isinstance(ev, ProfilePMCEvent)])
|
viz.get_profile(profile)
|
||||||
|
|
||||||
|
# List all kernels
|
||||||
|
if args.kernel is None:
|
||||||
|
for c in viz.ctxs:
|
||||||
|
print(c["name"])
|
||||||
|
for s in c["steps"]: print(" "+s["name"])
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find kernel trace
|
||||||
|
trace = next((c for c in viz.ctxs if c["name"] == f"Exec {args.kernel}"), None)
|
||||||
|
if not trace: raise RuntimeError(f"no matching trace for {args.kernel}")
|
||||||
|
n = 0
|
||||||
|
for s in trace["steps"]:
|
||||||
|
print(s["name"])
|
||||||
|
data = viz.get_render(s["query"])
|
||||||
|
print_data(data)
|
||||||
|
n += 1
|
||||||
|
if n > args.n: break
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ AMD_CC, CPU_CC, NV_CC, CUDA_CC = ContextVar("AMD_CC", ""), ContextVar("CPU_CC",
|
|||||||
QCOM_CC = ContextVar("QCOM_CC", "")
|
QCOM_CC = ContextVar("QCOM_CC", "")
|
||||||
# VIZ implies PROFILE, but you can run PROFILE without VIZ
|
# VIZ implies PROFILE, but you can run PROFILE without VIZ
|
||||||
VIZ = ContextVar("VIZ", 0)
|
VIZ = ContextVar("VIZ", 0)
|
||||||
PROFILE = ContextVar("PROFILE", VIZ.value)
|
PROFILE = ContextVar("PROFILE", abs(VIZ.value))
|
||||||
SPEC = ContextVar("SPEC", 1)
|
SPEC = ContextVar("SPEC", 1)
|
||||||
# TODO: disable by default due to speed
|
# TODO: disable by default due to speed
|
||||||
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
|
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
|
||||||
|
|||||||
@@ -1176,7 +1176,7 @@ if TRACK_MATCH_STATS or PROFILE:
|
|||||||
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
|
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
|
||||||
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
|
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
|
||||||
if VIZ > 0: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
|
if VIZ > 0: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
|
||||||
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
|
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value and VIZ.value>=0):
|
||||||
ret = [0,0,0.0,0.0]
|
ret = [0,0,0.0,0.0]
|
||||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
|
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
|
||||||
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
|
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ def load_counters(profile:list[ProfileEvent]) -> None:
|
|||||||
# run our decoder on startup, we don't use this since it only works on gfx11
|
# run our decoder on startup, we don't use this since it only works on gfx11
|
||||||
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
|
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
|
||||||
for e in sqtt: parse_sqtt_print_packets(e.blob)
|
for e in sqtt: parse_sqtt_print_packets(e.blob)
|
||||||
ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps})
|
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||||
|
|
||||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||||
|
|
||||||
@@ -424,7 +424,9 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
|
|||||||
|
|
||||||
# ** Main render function to get the complete details about a trace event
|
# ** Main render function to get the complete details about a trace event
|
||||||
|
|
||||||
def get_render(i:int, j:int, fmt:str) -> dict:
|
def get_render(query:str) -> dict:
|
||||||
|
url = urlparse(query)
|
||||||
|
i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/")
|
||||||
data = ctxs[i]["steps"][j]["data"]
|
data = ctxs[i]["steps"][j]["data"]
|
||||||
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
|
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
|
||||||
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
|
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
|
||||||
@@ -504,16 +506,17 @@ class Handler(HTTPRequestHandler):
|
|||||||
if url.path.endswith(".js"): content_type = "application/javascript"
|
if url.path.endswith(".js"): content_type = "application/javascript"
|
||||||
if url.path.endswith(".css"): content_type = "text/css"
|
if url.path.endswith(".css"): content_type = "text/css"
|
||||||
except FileNotFoundError: status_code = 404
|
except FileNotFoundError: status_code = 404
|
||||||
elif (query:=parse_qs(url.query)):
|
|
||||||
render_src = get_render(get_int(query, "ctx"), get_int(query, "step"), url.path.lstrip("/"))
|
|
||||||
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
|
|
||||||
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
|
|
||||||
if content_type == "text/event-stream": return self.stream_json(render_src["value"])
|
|
||||||
elif url.path == "/ctxs":
|
elif url.path == "/ctxs":
|
||||||
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs]
|
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs]
|
||||||
ret, content_type = json.dumps(lst).encode(), "application/json"
|
ret, content_type = json.dumps(lst).encode(), "application/json"
|
||||||
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
|
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
|
||||||
else: status_code = 404
|
else:
|
||||||
|
if not (render_src:=get_render(self.path)): status_code = 404
|
||||||
|
else:
|
||||||
|
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
|
||||||
|
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
|
||||||
|
if content_type == "text/event-stream": return self.stream_json(render_src["value"])
|
||||||
|
|
||||||
return self.send_data(ret, content_type, status_code)
|
return self.send_data(ret, content_type, status_code)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user