mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
move perfetto to extra (#10994)
* move perfetto to extra * update TestViz and fix tests * remove perfetto.html from viz directory * work * mypy
This commit is contained in:
38
extra/perfetto/to_perfetto.py
Normal file
38
extra/perfetto/to_perfetto.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import sys, pickle, decimal, json
|
||||
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
|
||||
from tinygrad.helpers import tqdm, temp
|
||||
|
||||
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
|
||||
def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy])
|
||||
def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)}
|
||||
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
|
||||
devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, len(devices))
|
||||
return [{"name": "process_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "args": {"name": ev.device}},
|
||||
{"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 0, "args": {"name": "COMPUTE"}},
|
||||
{"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 1, "args": {"name": "COPY"}}]
|
||||
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
|
||||
return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}]
|
||||
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
|
||||
ret = []
|
||||
for i,e in enumerate(ev.ents):
|
||||
st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]
|
||||
ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
|
||||
for dep in ev.deps[i]:
|
||||
d = ev.ents[dep]
|
||||
ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]
|
||||
ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]
|
||||
return ret
|
||||
def to_perfetto(profile:list[ProfileEvent]):
|
||||
# Start json with devices.
|
||||
prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)]
|
||||
for ev in tqdm(profile, desc="preparing profile"):
|
||||
if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
|
||||
elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))
|
||||
return {"traceEvents": prof_json}
|
||||
|
||||
if __name__ == "__main__":
|
||||
fp = sys.argv[1]
|
||||
with open(fp, "rb") as f: profile = pickle.load(f)
|
||||
ret = to_perfetto(profile)
|
||||
with open(fp:=temp("perfetto.json", append_user=True), "w") as f: json.dump(ret, f)
|
||||
print(f"Saved perfetto output to {fp}. You can use upload this to the perfetto UI or Chrome devtools.")
|
||||
2
setup.py
2
setup.py
@@ -28,7 +28,7 @@ setup(name='tinygrad',
|
||||
'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support', 'tinygrad.kernelize',
|
||||
'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.uop', 'tinygrad.opt',
|
||||
'tinygrad.runtime.support.nv'],
|
||||
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*', 'js/*']},
|
||||
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'assets/**/*', 'js/*']},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License"
|
||||
|
||||
@@ -195,50 +195,32 @@ class TestVizIntegration(TestViz):
|
||||
self.assertEqual(lst[1]["name"], prg.name)
|
||||
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import to_perfetto
|
||||
from tinygrad.viz.serve import get_profile
|
||||
|
||||
class TextVizProfiler(unittest.TestCase):
|
||||
class TestVizProfiler(unittest.TestCase):
|
||||
def test_perfetto_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
|
||||
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
j = json.loads(get_profile(prof))
|
||||
|
||||
# Device regs always first
|
||||
self.assertEqual(j['traceEvents'][0]['name'], 'process_name')
|
||||
self.assertEqual(j['traceEvents'][0]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
|
||||
|
||||
self.assertEqual(j['traceEvents'][1]['name'], 'thread_name')
|
||||
self.assertEqual(j['traceEvents'][1]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][1]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][1]['tid'], 0)
|
||||
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
|
||||
|
||||
self.assertEqual(j['traceEvents'][2]['name'], 'thread_name')
|
||||
self.assertEqual(j['traceEvents'][2]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][2]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][2]['tid'], 1)
|
||||
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
|
||||
|
||||
self.assertEqual(j['traceEvents'][3]['name'], 'E_2')
|
||||
self.assertEqual(j['traceEvents'][3]['ts'], 0)
|
||||
self.assertEqual(j['traceEvents'][3]['dur'], 10)
|
||||
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
|
||||
self.assertEqual(j['traceEvents'][3]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][3]['tid'], 0)
|
||||
dev_events = j['devEvents']['NV']
|
||||
self.assertEqual(len(dev_events), 1)
|
||||
event = dev_events[0]
|
||||
self.assertEqual(event['name'], 'E_2')
|
||||
self.assertEqual(event['ts'], 0)
|
||||
self.assertEqual(event['dur'], 10)
|
||||
|
||||
def test_perfetto_copy_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
|
||||
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
j = json.loads(get_profile(prof))
|
||||
|
||||
self.assertEqual(j['traceEvents'][3]['name'], 'COPYxx')
|
||||
self.assertEqual(j['traceEvents'][3]['ts'], 900) # diff clock
|
||||
self.assertEqual(j['traceEvents'][3]['dur'], 10)
|
||||
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
|
||||
self.assertEqual(j['traceEvents'][3]['tid'], 1)
|
||||
event = j['devEvents']['NV'][0]
|
||||
self.assertEqual(event['name'], 'COPYxx')
|
||||
self.assertEqual(event['ts'], 900) # diff clock
|
||||
self.assertEqual(event['dur'], 10)
|
||||
|
||||
def test_perfetto_graph(self):
|
||||
prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
|
||||
@@ -248,25 +230,22 @@ class TextVizProfiler(unittest.TestCase):
|
||||
deps=[[], [0]],
|
||||
sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
j = json.loads(get_profile(prof))
|
||||
|
||||
# Device regs always first
|
||||
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
|
||||
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
|
||||
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
|
||||
self.assertEqual(j['traceEvents'][3]['args']['name'], 'NV:1')
|
||||
self.assertEqual(j['traceEvents'][4]['args']['name'], 'COMPUTE')
|
||||
self.assertEqual(j['traceEvents'][5]['args']['name'], 'COPY')
|
||||
devices = list(j['devEvents'])
|
||||
self.assertEqual(devices[0], 'NV')
|
||||
self.assertEqual(devices[1], 'NV:1')
|
||||
|
||||
self.assertEqual(j['traceEvents'][6]['name'], 'E_25_4n2')
|
||||
self.assertEqual(j['traceEvents'][6]['ts'], 0)
|
||||
self.assertEqual(j['traceEvents'][6]['dur'], 2)
|
||||
self.assertEqual(j['traceEvents'][6]['pid'], j['traceEvents'][0]['pid'])
|
||||
nv_events = j['devEvents']['NV']
|
||||
self.assertEqual(nv_events[0]['name'], 'E_25_4n2')
|
||||
self.assertEqual(nv_events[0]['ts'], 0)
|
||||
self.assertEqual(nv_events[0]['dur'], 2)
|
||||
#self.assertEqual(j['devEvents'][6]['pid'], j['devEvents'][0]['pid'])
|
||||
|
||||
self.assertEqual(j['traceEvents'][7]['name'], 'NV -> NV:1')
|
||||
self.assertEqual(j['traceEvents'][7]['ts'], 954)
|
||||
self.assertEqual(j['traceEvents'][7]['dur'], 4)
|
||||
self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])
|
||||
nv1_events = j['devEvents']['NV:1']
|
||||
self.assertEqual(nv1_events[0]['name'], 'NV -> NV:1')
|
||||
self.assertEqual(nv1_events[0]['ts'], 954)
|
||||
#self.assertEqual(j['devEvents'][7]['pid'], j['devEvents'][3]['pid'])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -236,21 +236,8 @@ async function renderProfiler() {
|
||||
displayGraph("profiler");
|
||||
d3.select(".metadata").html("");
|
||||
if (data != null) return;
|
||||
// fetch and process data
|
||||
const { traceEvents } = await (await fetch("/get_profile")).json();
|
||||
let st, et;
|
||||
const events = new Map();
|
||||
for (const e of traceEvents) {
|
||||
if (e.name === "process_name") events.set(e.pid, { name:e.args.name, events:[] });
|
||||
if (e.ph === "X") {
|
||||
if (st == null) [st, et] = [e.ts, e.ts+e.dur];
|
||||
else {
|
||||
st = Math.min(st, e.ts);
|
||||
et = Math.max(et, e.ts+e.dur);
|
||||
}
|
||||
events.get(e.pid).events.push(e);
|
||||
}
|
||||
}
|
||||
const { devEvents, st, et } = await (await fetch("/get_profile")).json();
|
||||
const events = new Map(Object.entries(devEvents));
|
||||
const kernelMap = new Map();
|
||||
for (const [i, c] of ctxs.entries()) kernelMap.set(c.function_name, { name:c.name, i });
|
||||
// place devices on the y axis and set vertical positions
|
||||
@@ -264,18 +251,18 @@ async function renderProfiler() {
|
||||
const nameMap = new Map();
|
||||
data = [];
|
||||
for (const [k, v] of events) {
|
||||
if (v.events.length === 0) continue;
|
||||
if (v.length === 0) continue;
|
||||
const div = deviceList.appendChild(document.createElement("div"));
|
||||
div.id = `pid-${k}`;
|
||||
div.innerText = v.name;
|
||||
div.id = k;
|
||||
div.innerText = k;
|
||||
div.style.padding = `${padding}px`;
|
||||
const { y:baseY, height:baseHeight } = rect(`#pid-${k}`);
|
||||
const { y:baseY, height:baseHeight } = rect(`#${k}`);
|
||||
// position events on the y axis, stack ones that overlap
|
||||
const levels = [];
|
||||
v.events.sort((a,b) => (a.ts-st) - (b.ts-st));
|
||||
v.sort((a,b) => (a.ts-st) - (b.ts-st));
|
||||
const levelHeight = baseHeight-padding;
|
||||
const offsetY = baseY-canvasTop+padding/2;
|
||||
for (const [i,e] of v.events.entries()) {
|
||||
for (const [i,e] of v.entries()) {
|
||||
// assign to the first free depth
|
||||
const start = e.ts-st;
|
||||
const end = start+e.dur;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver, functools
|
||||
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Any, TypedDict, Generator
|
||||
@@ -92,33 +92,28 @@ def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None,
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
# Profiler API
|
||||
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
|
||||
def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy])
|
||||
def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)}
|
||||
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
|
||||
devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, len(devices))
|
||||
return [{"name": "process_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "args": {"name": ev.device}},
|
||||
{"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 0, "args": {"name": "COMPUTE"}},
|
||||
{"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 1, "args": {"name": "COPY"}}]
|
||||
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
|
||||
return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}]
|
||||
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
|
||||
ret = []
|
||||
for i,e in enumerate(ev.ents):
|
||||
st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]
|
||||
ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
|
||||
for dep in ev.deps[i]:
|
||||
d = ev.ents[dep]
|
||||
ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]
|
||||
ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]
|
||||
return ret
|
||||
def to_perfetto(profile:list[ProfileEvent]):
|
||||
# Start json with devices.
|
||||
prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)]
|
||||
for ev in tqdm(profile, desc="preparing profile"):
|
||||
if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
|
||||
elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))
|
||||
return json.dumps({"traceEvents": prof_json}).encode() if len(prof_json) > 0 else None
|
||||
|
||||
def events_to_json(profile:list[ProfileEvent]):
|
||||
for e in profile:
|
||||
if isinstance(e, ProfileRangeEvent): yield (e.device, e.name, e.st, e.en, e.is_copy)
|
||||
if isinstance(e, ProfileGraphEvent):
|
||||
for ent in e.ents: yield (ent.device, ent.name, e.sigs[ent.st_id], e.sigs[ent.en_id], ent.is_copy)
|
||||
|
||||
def get_profile(profile:list[ProfileEvent]):
|
||||
# start by getting the time diffs
|
||||
devs = {e.device:(e.comp_tdiff, e.copy_tdiff if e.copy_tdiff is not None else e.comp_tdiff) for e in profile if isinstance(e,ProfileDeviceEvent)}
|
||||
# map events per device
|
||||
dev_events:dict[str, list] = {}
|
||||
min_ts:int|None = None
|
||||
max_ts:int|None = None
|
||||
for device, name, ts, en, is_copy in events_to_json(profile):
|
||||
time_diff = devs[device][is_copy]
|
||||
st = int(ts+time_diff)
|
||||
et = st if en is None else int(en+time_diff)
|
||||
dev_events.setdefault(device,[]).append({"name":name, "ts":st, "dur":et-st})
|
||||
if min_ts is None or st < min_ts: min_ts = st
|
||||
if max_ts is None or et > max_ts: max_ts = et
|
||||
return json.dumps({"devEvents":dev_events, "st":min_ts, "et":max_ts}).encode("utf-8")
|
||||
|
||||
# ** HTTP server
|
||||
|
||||
@@ -126,7 +121,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
ret, status_code, content_type = b"", 200, "text/html"
|
||||
|
||||
if (fn:={"/":"index", "/profiler":"perfetto"}.get((url:=urlparse(self.path)).path)):
|
||||
if (fn:={"/":"index"}.get((url:=urlparse(self.path)).path)):
|
||||
with open(os.path.join(os.path.dirname(__file__), f"{fn}.html"), "rb") as f: ret = f.read()
|
||||
elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
|
||||
try:
|
||||
@@ -151,7 +146,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
# pass if client closed connection
|
||||
except (BrokenPipeError, ConnectionResetError): return
|
||||
ret, content_type = json.dumps(ctxs).encode(), "application/json"
|
||||
elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
|
||||
elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/json"
|
||||
else: status_code = 404
|
||||
|
||||
# send response
|
||||
@@ -197,7 +192,7 @@ if __name__ == "__main__":
|
||||
# NOTE: this context is a tuple of list[keys] and list[values]
|
||||
ctxs = get_metadata(*contexts[:2]) if contexts is not None else []
|
||||
|
||||
perfetto_profile = to_perfetto(profile) if profile is not None else None
|
||||
profile_ret = get_profile(profile) if profile is not None else None
|
||||
|
||||
server = TCPServerWithReuse(('', PORT), Handler)
|
||||
reloader_thread = threading.Thread(target=reloader)
|
||||
|
||||
Reference in New Issue
Block a user