diff --git a/tinygrad/viz/perfetto.html b/extra/perfetto/perfetto.html
similarity index 100%
rename from tinygrad/viz/perfetto.html
rename to extra/perfetto/perfetto.html
diff --git a/extra/perfetto/to_perfetto.py b/extra/perfetto/to_perfetto.py
new file mode 100644
index 0000000000..814fa0ae46
--- /dev/null
+++ b/extra/perfetto/to_perfetto.py
@@ -0,0 +1,38 @@
+import sys, pickle, decimal, json
+from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
+from tinygrad.helpers import tqdm, temp
+
+devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}  # device name -> (comp_tdiff, copy_tdiff, pid), filled by dev_ev_to_perfetto_json
+def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy])  # is_copy picks the offset: 0=compute, 1=copy
+def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)}  # tid 0 is the COMPUTE thread, tid 1 the COPY thread
+def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):  # registers the device in `devices` and returns its process/thread metadata ("M") events
+  pid = devices[ev.device][2] if ev.device in devices else len(devices)  # reuse the pid on re-registration so a later device can never get the same one
+  devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, pid)  # copy offset falls back to the compute offset
+  return [{"name": "process_name", "ph": "M", "pid": pid, "args": {"name": ev.device}},
+          *({"name": "thread_name", "ph": "M", "pid": pid, "tid": tid, "args": {"name": nm}} for tid, nm in enumerate(("COMPUTE", "COPY")))]
+def range_ev_to_perfetto_json(ev:ProfileRangeEvent):  # one "X" event per range; ts is shifted by the device offset, dur is en-st
+ return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}]
+def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):  # reccnt: number of events already emitted, used as the base for flow ids
+ ret = []
+ for i,e in enumerate(ev.ents):
+ st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]  # start/end timestamps are looked up in ev.sigs by index
+ ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
+ for dep in ev.deps[i]:  # ev.deps[i] holds indices into ev.ents
+ d = ev.ents[dep]
+ ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]  # "s" event placed at the dependency's end timestamp
+ ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]  # "f" event at this entry's start; the -1 gives it the same id as the "s" just appended
+ return ret
+def to_perfetto(profile:list[ProfileEvent]):  # converts a profile into a {"traceEvents": [...]} dict
+ # Emit device metadata first: this also fills `devices`, which prep_ts/dev_to_pid read for the events below.
+ prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)]
+ for ev in tqdm(profile, desc="preparing profile"):
+ if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
+ elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))  # reccnt keeps flow ids distinct across graph events
+ return {"traceEvents": prof_json}
+
+if __name__ == "__main__":
+  fp = sys.argv[1]  # input path; expected to hold a pickled list[ProfileEvent]
+  with open(fp, "rb") as f: profile = pickle.load(f)  # NOTE: unpickling can run arbitrary code, only load profiles you trust
+  ret = to_perfetto(profile)
+  with open(fp:=temp("perfetto.json", append_user=True), "w") as f: json.dump(ret, f)  # fp now names the output file
+  print(f"Saved perfetto output to {fp}. You can upload this to the perfetto UI or Chrome devtools.")
diff --git a/setup.py b/setup.py
index 3f56dee427..5d4d0fc3c6 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,7 @@ setup(name='tinygrad',
'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support', 'tinygrad.kernelize',
'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.uop', 'tinygrad.opt',
'tinygrad.runtime.support.nv'],
- package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*', 'js/*']},
+ package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'assets/**/*', 'js/*']},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"
diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py
index a2fe5e8012..0d0044f21b 100644
--- a/test/unit/test_viz.py
+++ b/test/unit/test_viz.py
@@ -195,50 +195,32 @@ class TestVizIntegration(TestViz):
self.assertEqual(lst[1]["name"], prg.name)
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
-from tinygrad.viz.serve import to_perfetto
+from tinygrad.viz.serve import get_profile
-class TextVizProfiler(unittest.TestCase):
+class TestVizProfiler(unittest.TestCase):
def test_perfetto_node(self):
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
- j = json.loads(to_perfetto(prof))
+ j = json.loads(get_profile(prof))
- # Device regs always first
- self.assertEqual(j['traceEvents'][0]['name'], 'process_name')
- self.assertEqual(j['traceEvents'][0]['ph'], 'M')
- self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
-
- self.assertEqual(j['traceEvents'][1]['name'], 'thread_name')
- self.assertEqual(j['traceEvents'][1]['ph'], 'M')
- self.assertEqual(j['traceEvents'][1]['pid'], j['traceEvents'][0]['pid'])
- self.assertEqual(j['traceEvents'][1]['tid'], 0)
- self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
-
- self.assertEqual(j['traceEvents'][2]['name'], 'thread_name')
- self.assertEqual(j['traceEvents'][2]['ph'], 'M')
- self.assertEqual(j['traceEvents'][2]['pid'], j['traceEvents'][0]['pid'])
- self.assertEqual(j['traceEvents'][2]['tid'], 1)
- self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
-
- self.assertEqual(j['traceEvents'][3]['name'], 'E_2')
- self.assertEqual(j['traceEvents'][3]['ts'], 0)
- self.assertEqual(j['traceEvents'][3]['dur'], 10)
- self.assertEqual(j['traceEvents'][3]['ph'], 'X')
- self.assertEqual(j['traceEvents'][3]['pid'], j['traceEvents'][0]['pid'])
- self.assertEqual(j['traceEvents'][3]['tid'], 0)
+ dev_events = j['devEvents']['NV']
+ self.assertEqual(len(dev_events), 1)
+ event = dev_events[0]
+ self.assertEqual(event['name'], 'E_2')
+ self.assertEqual(event['ts'], 0)
+ self.assertEqual(event['dur'], 10)
def test_perfetto_copy_node(self):
prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
- j = json.loads(to_perfetto(prof))
+ j = json.loads(get_profile(prof))
- self.assertEqual(j['traceEvents'][3]['name'], 'COPYxx')
- self.assertEqual(j['traceEvents'][3]['ts'], 900) # diff clock
- self.assertEqual(j['traceEvents'][3]['dur'], 10)
- self.assertEqual(j['traceEvents'][3]['ph'], 'X')
- self.assertEqual(j['traceEvents'][3]['tid'], 1)
+ event = j['devEvents']['NV'][0]
+ self.assertEqual(event['name'], 'COPYxx')
+ self.assertEqual(event['ts'], 900) # diff clock
+ self.assertEqual(event['dur'], 10)
def test_perfetto_graph(self):
prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
@@ -248,25 +230,22 @@ class TextVizProfiler(unittest.TestCase):
deps=[[], [0]],
sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
- j = json.loads(to_perfetto(prof))
+ j = json.loads(get_profile(prof))
- # Device regs always first
- self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
- self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
- self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
- self.assertEqual(j['traceEvents'][3]['args']['name'], 'NV:1')
- self.assertEqual(j['traceEvents'][4]['args']['name'], 'COMPUTE')
- self.assertEqual(j['traceEvents'][5]['args']['name'], 'COPY')
+ devices = list(j['devEvents'])
+ self.assertEqual(devices[0], 'NV')
+ self.assertEqual(devices[1], 'NV:1')
- self.assertEqual(j['traceEvents'][6]['name'], 'E_25_4n2')
- self.assertEqual(j['traceEvents'][6]['ts'], 0)
- self.assertEqual(j['traceEvents'][6]['dur'], 2)
- self.assertEqual(j['traceEvents'][6]['pid'], j['traceEvents'][0]['pid'])
+ nv_events = j['devEvents']['NV']
+ self.assertEqual(nv_events[0]['name'], 'E_25_4n2')
+ self.assertEqual(nv_events[0]['ts'], 0)
+ self.assertEqual(nv_events[0]['dur'], 2)
+ # NOTE: devEvents is keyed by device name, so the pid checks from the traceEvents format no longer apply
- self.assertEqual(j['traceEvents'][7]['name'], 'NV -> NV:1')
- self.assertEqual(j['traceEvents'][7]['ts'], 954)
- self.assertEqual(j['traceEvents'][7]['dur'], 4)
- self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])
+ nv1_events = j['devEvents']['NV:1']
+ self.assertEqual(nv1_events[0]['name'], 'NV -> NV:1')
+ self.assertEqual(nv1_events[0]['ts'], 954)
+ # TODO: also assert nv1_events[0]['dur']; the traceEvents version of this test checked dur == 4
if __name__ == "__main__":
unittest.main()
diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js
index 4d05d99328..63416fca16 100644
--- a/tinygrad/viz/js/index.js
+++ b/tinygrad/viz/js/index.js
@@ -236,21 +236,8 @@ async function renderProfiler() {
displayGraph("profiler");
d3.select(".metadata").html("");
if (data != null) return;
- // fetch and process data
- const { traceEvents } = await (await fetch("/get_profile")).json();
- let st, et;
- const events = new Map();
- for (const e of traceEvents) {
- if (e.name === "process_name") events.set(e.pid, { name:e.args.name, events:[] });
- if (e.ph === "X") {
- if (st == null) [st, et] = [e.ts, e.ts+e.dur];
- else {
- st = Math.min(st, e.ts);
- et = Math.max(et, e.ts+e.dur);
- }
- events.get(e.pid).events.push(e);
- }
- }
+ const { devEvents, st, et } = await (await fetch("/get_profile")).json();
+ const events = new Map(Object.entries(devEvents));
const kernelMap = new Map();
for (const [i, c] of ctxs.entries()) kernelMap.set(c.function_name, { name:c.name, i });
// place devices on the y axis and set vertical positions
@@ -264,18 +251,18 @@ async function renderProfiler() {
const nameMap = new Map();
data = [];
for (const [k, v] of events) {
- if (v.events.length === 0) continue;
+ if (v.length === 0) continue;
const div = deviceList.appendChild(document.createElement("div"));
- div.id = `pid-${k}`;
- div.innerText = v.name;
+ div.id = k;
+ div.innerText = k;
div.style.padding = `${padding}px`;
- const { y:baseY, height:baseHeight } = rect(`#pid-${k}`);
+ const { y:baseY, height:baseHeight } = rect(`#${k}`);
// position events on the y axis, stack ones that overlap
const levels = [];
- v.events.sort((a,b) => (a.ts-st) - (b.ts-st));
+ v.sort((a,b) => (a.ts-st) - (b.ts-st));
const levelHeight = baseHeight-padding;
const offsetY = baseY-canvasTop+padding/2;
- for (const [i,e] of v.events.entries()) {
+ for (const [i,e] of v.entries()) {
// assign to the first free depth
const start = e.ts-st;
const end = start+e.dur;
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 46674b9690..a2cd46c819 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver, functools
+import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools
from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, Generator
@@ -92,33 +92,28 @@ def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None,
if not ctx.bottom_up: next_sink = new_sink
# Profiler API
-devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
-def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy])
-def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)}
-def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
- devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, len(devices))
- return [{"name": "process_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "args": {"name": ev.device}},
- {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 0, "args": {"name": "COMPUTE"}},
- {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 1, "args": {"name": "COPY"}}]
-def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
- return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}]
-def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
- ret = []
- for i,e in enumerate(ev.ents):
- st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]
- ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
- for dep in ev.deps[i]:
- d = ev.ents[dep]
- ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]
- ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]
- return ret
-def to_perfetto(profile:list[ProfileEvent]):
- # Start json with devices.
- prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)]
- for ev in tqdm(profile, desc="preparing profile"):
- if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
- elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))
- return json.dumps({"traceEvents": prof_json}).encode() if len(prof_json) > 0 else None
+
+def events_to_json(profile:list[ProfileEvent]):  # yields (device, name, st, en, is_copy) tuples; despite the name, nothing is JSON-encoded here
+ for e in profile:
+ if isinstance(e, ProfileRangeEvent): yield (e.device, e.name, e.st, e.en, e.is_copy)
+ if isinstance(e, ProfileGraphEvent):
+ for ent in e.ents: yield (ent.device, ent.name, e.sigs[ent.st_id], e.sigs[ent.en_id], ent.is_copy)  # graph entries take start/end from the event's sigs list
+
+def get_profile(profile:list[ProfileEvent]):  # returns UTF-8 JSON bytes: {"devEvents": {device: [{name, ts, dur}]}, "st": min ts, "et": max end}
+ # per-device clock offsets as (compute, copy); copy falls back to compute when copy_tdiff is None
+ devs = {e.device:(e.comp_tdiff, e.copy_tdiff if e.copy_tdiff is not None else e.comp_tdiff) for e in profile if isinstance(e,ProfileDeviceEvent)}
+ # group events per device and track the overall time range
+ dev_events:dict[str, list] = {}
+ min_ts:int|None = None
+ max_ts:int|None = None
+ for device, name, ts, en, is_copy in events_to_json(profile):
+ time_diff = devs[device][is_copy]  # is_copy picks the offset: False=compute, True=copy
+ st = int(ts+time_diff)
+ et = st if en is None else int(en+time_diff)  # an event without an end time gets zero duration
+ dev_events.setdefault(device,[]).append({"name":name, "ts":st, "dur":et-st})
+ if min_ts is None or st < min_ts: min_ts = st
+ if max_ts is None or et > max_ts: max_ts = et
+ return json.dumps({"devEvents":dev_events, "st":min_ts, "et":max_ts}).encode("utf-8")  # st/et serialize as null when the profile has no events
# ** HTTP server
@@ -126,7 +121,7 @@ class Handler(BaseHTTPRequestHandler):
def do_GET(self):
ret, status_code, content_type = b"", 200, "text/html"
- if (fn:={"/":"index", "/profiler":"perfetto"}.get((url:=urlparse(self.path)).path)):
+ if (fn:={"/":"index"}.get((url:=urlparse(self.path)).path)):
with open(os.path.join(os.path.dirname(__file__), f"{fn}.html"), "rb") as f: ret = f.read()
elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
try:
@@ -151,7 +146,7 @@ class Handler(BaseHTTPRequestHandler):
# pass if client closed connection
except (BrokenPipeError, ConnectionResetError): return
ret, content_type = json.dumps(ctxs).encode(), "application/json"
- elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
+ elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/json"
else: status_code = 404
# send response
@@ -197,7 +192,7 @@ if __name__ == "__main__":
# NOTE: this context is a tuple of list[keys] and list[values]
ctxs = get_metadata(*contexts[:2]) if contexts is not None else []
- perfetto_profile = to_perfetto(profile) if profile is not None else None
+ profile_ret = get_profile(profile) if profile is not None else None
server = TCPServerWithReuse(('', PORT), Handler)
reloader_thread = threading.Thread(target=reloader)