move perfetto to extra (#10994)

* move perfetto to extra

* update TestViz and fix tests

* remove perfetto.html from viz directory

* work

* mypy
This commit is contained in:
qazal
2025-06-27 01:53:54 +03:00
committed by GitHub
parent 712980e167
commit 1127302c46
6 changed files with 100 additions and 101 deletions

View File

@@ -0,0 +1,38 @@
import sys, pickle, decimal, json
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
from tinygrad.helpers import tqdm, temp
# Per-device registry, filled in by dev_ev_to_perfetto_json.
# Maps device name -> (comp_tdiff, copy_tdiff, pid): prep_ts indexes the first two entries with is_copy
# (False -> 0, True -> 1) and dev_to_pid reads the third entry as the "pid" of emitted events.
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
def prep_ts(device:str, ts:decimal.Decimal, is_copy):
  # Shift a raw timestamp by the time diff registered for `device` and return it as an int.
  # `is_copy` picks the entry of the registered tuple: falsy -> compute diff (index 0), truthy -> copy diff (index 1).
  # The device must already be present in `devices`, otherwise the lookup raises KeyError.
  shifted = decimal.Decimal(ts) + devices[device][is_copy]
  return int(shifted)
def dev_to_pid(device:str, is_copy=False):
  # Build the pid/tid pair used to place an event: pid is the index stored for the device in `devices`,
  # tid is 0 for compute and 1 for copy.
  pid = devices[device][2]
  tid = int(is_copy)
  return {"pid": pid, "tid": tid}
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
  # Register the device in `devices` and return its three metadata ("M") records:
  # one process_name record plus thread_name records for tid 0 (COMPUTE) and tid 1 (COPY).
  # When the event has no copy time diff, the compute time diff is used for copies too.
  copy_tdiff = ev.comp_tdiff if ev.copy_tdiff is None else ev.copy_tdiff
  # the pid is the number of devices registered so far, read before this entry is stored
  devices[ev.device] = (ev.comp_tdiff, copy_tdiff, len(devices))
  pid = dev_to_pid(ev.device)['pid']
  records = [{"name": "process_name", "ph": "M", "pid": pid, "args": {"name": ev.device}}]
  for tid, label in enumerate(("COMPUTE", "COPY")):
    records.append({"name": "thread_name", "ph": "M", "pid": pid, "tid": tid, "args": {"name": label}})
  return records
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
  # Convert one range event into a single-element list holding its "X" record.
  # "ts" is the start shifted by the device's time diff, "dur" is en-st as a float.
  record = {"name": ev.name, "ph": "X"}
  record["ts"] = prep_ts(ev.device, ev.st, ev.is_copy)
  record["dur"] = float(ev.en-ev.st)
  # pid/tid go last so the key order matches name, ph, ts, dur, pid, tid
  record.update(dev_to_pid(ev.device, ev.is_copy))
  return [record]
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
  # Convert a graph event into trace records: one "X" record per entry, followed by an "s"/"f" pair
  # for every dependency listed in ev.deps for that entry.
  # `reccnt` is the number of records emitted before this call; it offsets the ids of the "s"/"f" pairs.
  out = []
  for idx, ent in enumerate(ev.ents):
    start, end = ev.sigs[ent.st_id], ev.sigs[ent.en_id]
    ent_ts = prep_ts(ent.device, start, ent.is_copy)
    out.append({"name": ent.name, "ph": "X", "ts": ent_ts, "dur": float(end-start), **dev_to_pid(ent.device, ent.is_copy)})
    for dep_idx in ev.deps[idx]:
      src = ev.ents[dep_idx]
      # both halves of the pair share one id, taken from the record count before the "s" record is added
      flow_id = reccnt+len(out)
      src_ts = prep_ts(src.device, ev.sigs[src.en_id], src.is_copy)
      # the "s" record sits at the end of the dependency, the "f" record at the start of this entry
      out.append({"ph": "s", **dev_to_pid(src.device, src.is_copy), "id": flow_id, "ts": src_ts, "bp": "e"})
      out.append({"ph": "f", **dev_to_pid(ent.device, ent.is_copy), "id": flow_id, "ts": ent_ts, "bp": "e"})
  return out
def to_perfetto(profile:list[ProfileEvent]):
  # Build the {"traceEvents": [...]} dict for a list of profile events.
  trace_events = []
  # Device metadata goes first; this pass also registers every device before any timestamp is converted.
  for dev_ev in profile:
    if isinstance(dev_ev, ProfileDeviceEvent): trace_events.extend(dev_ev_to_perfetto_json(dev_ev))
  for ev in tqdm(profile, desc="preparing profile"):
    if isinstance(ev, ProfileRangeEvent):
      trace_events.extend(range_ev_to_perfetto_json(ev))
    elif isinstance(ev, ProfileGraphEvent):
      # reccnt is the number of records emitted so far
      trace_events.extend(graph_ev_to_perfetto_json(ev, reccnt=len(trace_events)))
  return {"traceEvents": trace_events}
if __name__ == "__main__":
  # Usage: <this script> <pickled profile>. Converts the profile and writes it out as perfetto.json.
  # Exit with a usage message instead of an IndexError traceback when the path argument is missing.
  if len(sys.argv) < 2: sys.exit(f"usage: {sys.argv[0]} <pickled profile>")
  in_fp = sys.argv[1]
  # NOTE: pickle.load can execute arbitrary code from the file, only pass profiles from a trusted source.
  with open(in_fp, "rb") as f: profile = pickle.load(f)
  ret = to_perfetto(profile)
  # the output path comes from the temp helper, it is not derived from the input path
  out_fp = temp("perfetto.json", append_user=True)
  with open(out_fp, "w") as f: json.dump(ret, f)
  print(f"Saved perfetto output to {out_fp}. You can upload this to the perfetto UI or Chrome devtools.")

View File

@@ -28,7 +28,7 @@ setup(name='tinygrad',
'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support', 'tinygrad.kernelize',
'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.uop', 'tinygrad.opt',
'tinygrad.runtime.support.nv'],
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*', 'js/*']},
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'assets/**/*', 'js/*']},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"

View File

@@ -195,50 +195,32 @@ class TestVizIntegration(TestViz):
self.assertEqual(lst[1]["name"], prg.name)
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import to_perfetto
from tinygrad.viz.serve import get_profile
class TextVizProfiler(unittest.TestCase):
class TestVizProfiler(unittest.TestCase):
def test_perfetto_node(self):
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
j = json.loads(to_perfetto(prof))
j = json.loads(get_profile(prof))
# Device regs always first
self.assertEqual(j['traceEvents'][0]['name'], 'process_name')
self.assertEqual(j['traceEvents'][0]['ph'], 'M')
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
self.assertEqual(j['traceEvents'][1]['name'], 'thread_name')
self.assertEqual(j['traceEvents'][1]['ph'], 'M')
self.assertEqual(j['traceEvents'][1]['pid'], j['traceEvents'][0]['pid'])
self.assertEqual(j['traceEvents'][1]['tid'], 0)
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
self.assertEqual(j['traceEvents'][2]['name'], 'thread_name')
self.assertEqual(j['traceEvents'][2]['ph'], 'M')
self.assertEqual(j['traceEvents'][2]['pid'], j['traceEvents'][0]['pid'])
self.assertEqual(j['traceEvents'][2]['tid'], 1)
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
self.assertEqual(j['traceEvents'][3]['name'], 'E_2')
self.assertEqual(j['traceEvents'][3]['ts'], 0)
self.assertEqual(j['traceEvents'][3]['dur'], 10)
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
self.assertEqual(j['traceEvents'][3]['pid'], j['traceEvents'][0]['pid'])
self.assertEqual(j['traceEvents'][3]['tid'], 0)
dev_events = j['devEvents']['NV']
self.assertEqual(len(dev_events), 1)
event = dev_events[0]
self.assertEqual(event['name'], 'E_2')
self.assertEqual(event['ts'], 0)
self.assertEqual(event['dur'], 10)
def test_perfetto_copy_node(self):
prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
j = json.loads(to_perfetto(prof))
j = json.loads(get_profile(prof))
self.assertEqual(j['traceEvents'][3]['name'], 'COPYxx')
self.assertEqual(j['traceEvents'][3]['ts'], 900) # diff clock
self.assertEqual(j['traceEvents'][3]['dur'], 10)
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
self.assertEqual(j['traceEvents'][3]['tid'], 1)
event = j['devEvents']['NV'][0]
self.assertEqual(event['name'], 'COPYxx')
self.assertEqual(event['ts'], 900) # diff clock
self.assertEqual(event['dur'], 10)
def test_perfetto_graph(self):
prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
@@ -248,25 +230,22 @@ class TextVizProfiler(unittest.TestCase):
deps=[[], [0]],
sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
j = json.loads(to_perfetto(prof))
j = json.loads(get_profile(prof))
# Device regs always first
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
self.assertEqual(j['traceEvents'][3]['args']['name'], 'NV:1')
self.assertEqual(j['traceEvents'][4]['args']['name'], 'COMPUTE')
self.assertEqual(j['traceEvents'][5]['args']['name'], 'COPY')
devices = list(j['devEvents'])
self.assertEqual(devices[0], 'NV')
self.assertEqual(devices[1], 'NV:1')
self.assertEqual(j['traceEvents'][6]['name'], 'E_25_4n2')
self.assertEqual(j['traceEvents'][6]['ts'], 0)
self.assertEqual(j['traceEvents'][6]['dur'], 2)
self.assertEqual(j['traceEvents'][6]['pid'], j['traceEvents'][0]['pid'])
nv_events = j['devEvents']['NV']
self.assertEqual(nv_events[0]['name'], 'E_25_4n2')
self.assertEqual(nv_events[0]['ts'], 0)
self.assertEqual(nv_events[0]['dur'], 2)
#self.assertEqual(j['devEvents'][6]['pid'], j['devEvents'][0]['pid'])
self.assertEqual(j['traceEvents'][7]['name'], 'NV -> NV:1')
self.assertEqual(j['traceEvents'][7]['ts'], 954)
self.assertEqual(j['traceEvents'][7]['dur'], 4)
self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])
nv1_events = j['devEvents']['NV:1']
self.assertEqual(nv1_events[0]['name'], 'NV -> NV:1')
self.assertEqual(nv1_events[0]['ts'], 954)
#self.assertEqual(j['devEvents'][7]['pid'], j['devEvents'][3]['pid'])
if __name__ == "__main__":
unittest.main()

View File

@@ -236,21 +236,8 @@ async function renderProfiler() {
displayGraph("profiler");
d3.select(".metadata").html("");
if (data != null) return;
// fetch and process data
const { traceEvents } = await (await fetch("/get_profile")).json();
let st, et;
const events = new Map();
for (const e of traceEvents) {
if (e.name === "process_name") events.set(e.pid, { name:e.args.name, events:[] });
if (e.ph === "X") {
if (st == null) [st, et] = [e.ts, e.ts+e.dur];
else {
st = Math.min(st, e.ts);
et = Math.max(et, e.ts+e.dur);
}
events.get(e.pid).events.push(e);
}
}
const { devEvents, st, et } = await (await fetch("/get_profile")).json();
const events = new Map(Object.entries(devEvents));
const kernelMap = new Map();
for (const [i, c] of ctxs.entries()) kernelMap.set(c.function_name, { name:c.name, i });
// place devices on the y axis and set vertical positions
@@ -264,18 +251,18 @@ async function renderProfiler() {
const nameMap = new Map();
data = [];
for (const [k, v] of events) {
if (v.events.length === 0) continue;
if (v.length === 0) continue;
const div = deviceList.appendChild(document.createElement("div"));
div.id = `pid-${k}`;
div.innerText = v.name;
div.id = k;
div.innerText = k;
div.style.padding = `${padding}px`;
const { y:baseY, height:baseHeight } = rect(`#pid-${k}`);
const { y:baseY, height:baseHeight } = rect(`#${k}`);
// position events on the y axis, stack ones that overlap
const levels = [];
v.events.sort((a,b) => (a.ts-st) - (b.ts-st));
v.sort((a,b) => (a.ts-st) - (b.ts-st));
const levelHeight = baseHeight-padding;
const offsetY = baseY-canvasTop+padding/2;
for (const [i,e] of v.events.entries()) {
for (const [i,e] of v.entries()) {
// assign to the first free depth
const start = e.ts-st;
const end = start+e.dur;

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver, functools
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools
from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, Generator
@@ -92,33 +92,28 @@ def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None,
if not ctx.bottom_up: next_sink = new_sink
# Profiler API
# Per-device registry, filled in by dev_ev_to_perfetto_json.
# Maps device name -> (comp_tdiff, copy_tdiff, pid): prep_ts indexes the first two entries with is_copy
# (False -> 0, True -> 1) and dev_to_pid reads the third entry as the "pid" of emitted events.
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
def prep_ts(device:str, ts:decimal.Decimal, is_copy):
  # Shift a raw timestamp by the time diff registered for `device` and return it as an int.
  # `is_copy` picks the entry of the registered tuple: falsy -> compute diff (index 0), truthy -> copy diff (index 1).
  # The device must already be present in `devices`, otherwise the lookup raises KeyError.
  shifted = decimal.Decimal(ts) + devices[device][is_copy]
  return int(shifted)
def dev_to_pid(device:str, is_copy=False):
  # Build the pid/tid pair used to place an event: pid is the index stored for the device in `devices`,
  # tid is 0 for compute and 1 for copy.
  pid = devices[device][2]
  tid = int(is_copy)
  return {"pid": pid, "tid": tid}
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
  # Register the device in `devices` and return its three metadata ("M") records:
  # one process_name record plus thread_name records for tid 0 (COMPUTE) and tid 1 (COPY).
  # When the event has no copy time diff, the compute time diff is used for copies too.
  copy_tdiff = ev.comp_tdiff if ev.copy_tdiff is None else ev.copy_tdiff
  # the pid is the number of devices registered so far, read before this entry is stored
  devices[ev.device] = (ev.comp_tdiff, copy_tdiff, len(devices))
  pid = dev_to_pid(ev.device)['pid']
  records = [{"name": "process_name", "ph": "M", "pid": pid, "args": {"name": ev.device}}]
  for tid, label in enumerate(("COMPUTE", "COPY")):
    records.append({"name": "thread_name", "ph": "M", "pid": pid, "tid": tid, "args": {"name": label}})
  return records
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
  # Convert one range event into a single-element list holding its "X" record.
  # "ts" is the start shifted by the device's time diff, "dur" is en-st as a float.
  record = {"name": ev.name, "ph": "X"}
  record["ts"] = prep_ts(ev.device, ev.st, ev.is_copy)
  record["dur"] = float(ev.en-ev.st)
  # pid/tid go last so the key order matches name, ph, ts, dur, pid, tid
  record.update(dev_to_pid(ev.device, ev.is_copy))
  return [record]
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
  # Convert a graph event into trace records: one "X" record per entry, followed by an "s"/"f" pair
  # for every dependency listed in ev.deps for that entry.
  # `reccnt` is the number of records emitted before this call; it offsets the ids of the "s"/"f" pairs.
  out = []
  for idx, ent in enumerate(ev.ents):
    start, end = ev.sigs[ent.st_id], ev.sigs[ent.en_id]
    ent_ts = prep_ts(ent.device, start, ent.is_copy)
    out.append({"name": ent.name, "ph": "X", "ts": ent_ts, "dur": float(end-start), **dev_to_pid(ent.device, ent.is_copy)})
    for dep_idx in ev.deps[idx]:
      src = ev.ents[dep_idx]
      # both halves of the pair share one id, taken from the record count before the "s" record is added
      flow_id = reccnt+len(out)
      src_ts = prep_ts(src.device, ev.sigs[src.en_id], src.is_copy)
      # the "s" record sits at the end of the dependency, the "f" record at the start of this entry
      out.append({"ph": "s", **dev_to_pid(src.device, src.is_copy), "id": flow_id, "ts": src_ts, "bp": "e"})
      out.append({"ph": "f", **dev_to_pid(ent.device, ent.is_copy), "id": flow_id, "ts": ent_ts, "bp": "e"})
  return out
def to_perfetto(profile:list[ProfileEvent]):
  # Serialize a list of profile events to UTF-8 encoded {"traceEvents": [...]} JSON bytes.
  # Returns None when the profile produced no records at all.
  trace_events = []
  # Device metadata goes first; this pass also registers every device before any timestamp is converted.
  for dev_ev in profile:
    if isinstance(dev_ev, ProfileDeviceEvent): trace_events.extend(dev_ev_to_perfetto_json(dev_ev))
  for ev in tqdm(profile, desc="preparing profile"):
    if isinstance(ev, ProfileRangeEvent):
      trace_events.extend(range_ev_to_perfetto_json(ev))
    elif isinstance(ev, ProfileGraphEvent):
      # reccnt is the number of records emitted so far
      trace_events.extend(graph_ev_to_perfetto_json(ev, reccnt=len(trace_events)))
  if not trace_events: return None
  return json.dumps({"traceEvents": trace_events}).encode()
def events_to_json(profile:list[ProfileEvent]):
  # Lazily flatten the profile into (device, name, start, end, is_copy) tuples.
  # A range event yields one tuple; a graph event yields one tuple per entry, with start/end looked up in its sigs.
  # Events of any other type yield nothing.
  for event in profile:
    if isinstance(event, ProfileRangeEvent):
      yield (event.device, event.name, event.st, event.en, event.is_copy)
    if isinstance(event, ProfileGraphEvent):
      yield from ((entry.device, entry.name, event.sigs[entry.st_id], event.sigs[entry.en_id], entry.is_copy)
                  for entry in event.ents)
def get_profile(profile:list[ProfileEvent]):
  # Serialize the profile to UTF-8 JSON bytes of the form {"devEvents": {device: [events]}, "st": ..., "et": ...}.
  # Each event is {"name", "ts", "dur"} with ts shifted by the device's time diff; "st"/"et" are the smallest
  # start and largest end over all events, and stay None when there are no events.
  # time diffs per device: index 0 is compute, index 1 is copy (compute diff is reused when no copy diff is set)
  tdiffs = {}
  for e in profile:
    if not isinstance(e, ProfileDeviceEvent): continue
    tdiffs[e.device] = (e.comp_tdiff, e.comp_tdiff if e.copy_tdiff is None else e.copy_tdiff)
  # group events per device, tracking the overall time bounds as we go
  dev_events:dict[str, list] = {}
  min_ts:int|None = None
  max_ts:int|None = None
  for device, name, ts, en, is_copy in events_to_json(profile):
    offset = tdiffs[device][is_copy]
    st = int(ts+offset)
    # an event without an end time gets zero duration
    et = int(en+offset) if en is not None else st
    if device not in dev_events: dev_events[device] = []
    dev_events[device].append({"name":name, "ts":st, "dur":et-st})
    min_ts = st if min_ts is None else min(min_ts, st)
    max_ts = et if max_ts is None else max(max_ts, et)
  payload = {"devEvents":dev_events, "st":min_ts, "et":max_ts}
  return json.dumps(payload).encode("utf-8")
# ** HTTP server
@@ -126,7 +121,7 @@ class Handler(BaseHTTPRequestHandler):
def do_GET(self):
ret, status_code, content_type = b"", 200, "text/html"
if (fn:={"/":"index", "/profiler":"perfetto"}.get((url:=urlparse(self.path)).path)):
if (fn:={"/":"index"}.get((url:=urlparse(self.path)).path)):
with open(os.path.join(os.path.dirname(__file__), f"{fn}.html"), "rb") as f: ret = f.read()
elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
try:
@@ -151,7 +146,7 @@ class Handler(BaseHTTPRequestHandler):
# pass if client closed connection
except (BrokenPipeError, ConnectionResetError): return
ret, content_type = json.dumps(ctxs).encode(), "application/json"
elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/json"
else: status_code = 404
# send response
@@ -197,7 +192,7 @@ if __name__ == "__main__":
# NOTE: this context is a tuple of list[keys] and list[values]
ctxs = get_metadata(*contexts[:2]) if contexts is not None else []
perfetto_profile = to_perfetto(profile) if profile is not None else None
profile_ret = get_profile(profile) if profile is not None else None
server = TCPServerWithReuse(('', PORT), Handler)
reloader_thread = threading.Thread(target=reloader)