diff --git a/tinygrad/viz/perfetto.html b/extra/perfetto/perfetto.html
similarity index 100%
rename from tinygrad/viz/perfetto.html
rename to extra/perfetto/perfetto.html
diff --git a/extra/perfetto/to_perfetto.py b/extra/perfetto/to_perfetto.py
new file mode 100644
index 0000000000..814fa0ae46
--- /dev/null
+++ b/extra/perfetto/to_perfetto.py
@@ -0,0 +1,42 @@
+# Convert a pickled tinygrad profile (list of ProfileEvent) to the traceEvents JSON read by the perfetto UI / Chrome devtools.
+import sys, pickle, decimal, json
+from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
+from tinygrad.helpers import tqdm, temp
+
+# device name -> (compute time diff, copy time diff, pid), filled in by dev_ev_to_perfetto_json
+devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
+def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy])
+def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)}
+def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
+  devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, len(devices))
+  return [{"name": "process_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "args": {"name": ev.device}},
+          {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 0, "args": {"name": "COMPUTE"}},
+          {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 1, "args": {"name": "COPY"}}]
+def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
+  return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}]
+def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
+  ret = []
+  for i,e in enumerate(ev.ents):
+    st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]
+    ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
+    for dep in ev.deps[i]:
+      d = ev.ents[dep]
+      # flow pair: "s" is placed at the end of the dependency, "f" at the start of this entry, both get the same id
+      ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]
+      ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]
+  return ret
+def to_perfetto(profile:list[ProfileEvent]):
+  # Start json with devices.
+  prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)]
+  for ev in tqdm(profile, desc="preparing profile"):
+    if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
+    elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))
+  return {"traceEvents": prof_json}
+
+if __name__ == "__main__":
+  fp = sys.argv[1]
+  # NOTE: pickle.load can execute arbitrary code, only pass profile files you trust
+  with open(fp, "rb") as f: profile = pickle.load(f)
+  ret = to_perfetto(profile)
+  with open(fp:=temp("perfetto.json", append_user=True), "w") as f: json.dump(ret, f)
+  print(f"Saved perfetto output to {fp}. You can upload this to the perfetto UI or Chrome devtools.")
diff --git a/setup.py b/setup.py
index 3f56dee427..5d4d0fc3c6 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,7 @@ setup(name='tinygrad',
                   'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support', 'tinygrad.kernelize',
                   'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.uop', 'tinygrad.opt',
                   'tinygrad.runtime.support.nv'],
-      package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*', 'js/*']},
+      package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'assets/**/*', 'js/*']},
       classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License"
diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py
index a2fe5e8012..0d0044f21b 100644
--- a/test/unit/test_viz.py
+++ b/test/unit/test_viz.py
@@ -195,50 +195,32 @@ class TestVizIntegration(TestViz):
     self.assertEqual(lst[1]["name"], prg.name)
 
 from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
-from tinygrad.viz.serve import to_perfetto
+from tinygrad.viz.serve import get_profile
 
-class TextVizProfiler(unittest.TestCase):
+class TestVizProfiler(unittest.TestCase):
   def test_perfetto_node(self):
     prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
             ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
 
-    j = json.loads(to_perfetto(prof))
+    j = json.loads(get_profile(prof))
 
-    # Device regs always first
-    self.assertEqual(j['traceEvents'][0]['name'], 'process_name')
-    self.assertEqual(j['traceEvents'][0]['ph'], 'M')
-    self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
-
-    self.assertEqual(j['traceEvents'][1]['name'], 'thread_name')
-    self.assertEqual(j['traceEvents'][1]['ph'], 'M')
-    self.assertEqual(j['traceEvents'][1]['pid'], j['traceEvents'][0]['pid'])
-    self.assertEqual(j['traceEvents'][1]['tid'], 0)
-    self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
-
-    self.assertEqual(j['traceEvents'][2]['name'], 'thread_name')
-    self.assertEqual(j['traceEvents'][2]['ph'], 'M')
-    self.assertEqual(j['traceEvents'][2]['pid'], j['traceEvents'][0]['pid'])
-    self.assertEqual(j['traceEvents'][2]['tid'], 1)
-    self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
-
-    self.assertEqual(j['traceEvents'][3]['name'], 'E_2')
-    self.assertEqual(j['traceEvents'][3]['ts'], 0)
-    self.assertEqual(j['traceEvents'][3]['dur'], 10)
-    self.assertEqual(j['traceEvents'][3]['ph'], 'X')
-    self.assertEqual(j['traceEvents'][3]['pid'], j['traceEvents'][0]['pid'])
-    self.assertEqual(j['traceEvents'][3]['tid'], 0)
+    dev_events = j['devEvents']['NV']
+    self.assertEqual(len(dev_events), 1)
+    event = dev_events[0]
+    self.assertEqual(event['name'], 'E_2')
+    self.assertEqual(event['ts'], 0)
+    self.assertEqual(event['dur'], 10)
 
   def test_perfetto_copy_node(self):
     prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
             ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
 
-    j = json.loads(to_perfetto(prof))
+    j = json.loads(get_profile(prof))
 
-    self.assertEqual(j['traceEvents'][3]['name'], 'COPYxx')
-    self.assertEqual(j['traceEvents'][3]['ts'], 900) # diff clock
-    self.assertEqual(j['traceEvents'][3]['dur'], 10)
-    self.assertEqual(j['traceEvents'][3]['ph'], 'X')
-    self.assertEqual(j['traceEvents'][3]['tid'], 1)
+    event = j['devEvents']['NV'][0]
+    self.assertEqual(event['name'], 'COPYxx')
+    self.assertEqual(event['ts'], 900) # diff clock
+    self.assertEqual(event['dur'], 10)
 
   def test_perfetto_graph(self):
     prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
@@ -248,25 +230,21 @@ class TextVizProfiler(unittest.TestCase):
                               deps=[[], [0]],
                               sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
 
-    j = json.loads(to_perfetto(prof))
+    j = json.loads(get_profile(prof))
 
-    # Device regs always first
-    self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
-    self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
-    self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
-    self.assertEqual(j['traceEvents'][3]['args']['name'], 'NV:1')
-    self.assertEqual(j['traceEvents'][4]['args']['name'], 'COMPUTE')
-    self.assertEqual(j['traceEvents'][5]['args']['name'], 'COPY')
+    devices = list(j['devEvents'])
+    self.assertEqual(devices[0], 'NV')
+    self.assertEqual(devices[1], 'NV:1')
 
-    self.assertEqual(j['traceEvents'][6]['name'], 'E_25_4n2')
-    self.assertEqual(j['traceEvents'][6]['ts'], 0)
-    self.assertEqual(j['traceEvents'][6]['dur'], 2)
-    self.assertEqual(j['traceEvents'][6]['pid'], j['traceEvents'][0]['pid'])
+    nv_events = j['devEvents']['NV']
+    self.assertEqual(nv_events[0]['name'], 'E_25_4n2')
+    self.assertEqual(nv_events[0]['ts'], 0)
+    self.assertEqual(nv_events[0]['dur'], 2)
 
-    self.assertEqual(j['traceEvents'][7]['name'], 'NV -> NV:1')
-    self.assertEqual(j['traceEvents'][7]['ts'], 954)
-    self.assertEqual(j['traceEvents'][7]['dur'], 4)
-    self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])
+    nv1_events = j['devEvents']['NV:1']
+    self.assertEqual(nv1_events[0]['name'], 'NV -> NV:1')
+    self.assertEqual(nv1_events[0]['ts'], 954)
+    self.assertEqual(nv1_events[0]['dur'], 4)
 
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js
index 4d05d99328..63416fca16 100644
--- a/tinygrad/viz/js/index.js
+++ b/tinygrad/viz/js/index.js
@@ -236,21 +236,8 @@ async function renderProfiler() {
   displayGraph("profiler");
   d3.select(".metadata").html("");
   if (data != null) return;
-  // fetch and process data
-  const { traceEvents } = await (await fetch("/get_profile")).json();
-  let st, et;
-  const events = new Map();
-  for (const e of traceEvents) {
-    if (e.name === "process_name") events.set(e.pid, { name:e.args.name, events:[] });
-    if (e.ph === "X") {
-      if (st == null) [st, et] = [e.ts, e.ts+e.dur];
-      else {
-        st = Math.min(st, e.ts);
-        et = Math.max(et, e.ts+e.dur);
-      }
-      events.get(e.pid).events.push(e);
-    }
-  }
+  const { devEvents, st, et } = await (await fetch("/get_profile")).json();
+  const events = new Map(Object.entries(devEvents));
   const kernelMap = new Map();
   for (const [i, c] of ctxs.entries()) kernelMap.set(c.function_name, { name:c.name, i });
   // place devices on the y axis and set vertical positions
@@ -264,18 +251,18 @@ async function renderProfiler() {
   const nameMap = new Map();
   data = [];
   for (const [k, v] of events) {
-    if (v.events.length === 0) continue;
+    if (v.length === 0) continue;
     const div = deviceList.appendChild(document.createElement("div"));
-    div.id = `pid-${k}`;
-    div.innerText = v.name;
+    div.id = k;
+    div.innerText = k;
     div.style.padding = `${padding}px`;
-    const { y:baseY, height:baseHeight } = rect(`#pid-${k}`);
+    const { y:baseY, height:baseHeight } = rect(`#${CSS.escape(k)}`); // device names like "NV:1" are not valid bare selectors
     // position events on the y axis, stack ones that overlap
     const levels = [];
-    v.events.sort((a,b) => (a.ts-st) - (b.ts-st));
+    v.sort((a,b) => (a.ts-st) - (b.ts-st));
     const levelHeight = baseHeight-padding;
     const offsetY = baseY-canvasTop+padding/2;
-    for (const [i,e] of v.events.entries()) {
+    for (const [i,e] of v.entries()) {
       // assign to the first free depth
       const start = e.ts-st;
       const end = start+e.dur;
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 46674b9690..a2cd46c819 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver, functools
+import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools
 from http.server import BaseHTTPRequestHandler
 from urllib.parse import parse_qs, urlparse
 from typing import Any, TypedDict, Generator
@@ -92,33 +92,31 @@ def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None,
     if not ctx.bottom_up: next_sink = new_sink
 
 # Profiler API
-devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
-def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy])
-def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)}
-def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
-  devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, len(devices))
-  return [{"name": "process_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "args": {"name": ev.device}},
-          {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 0, "args": {"name": "COMPUTE"}},
-          {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 1, "args": {"name": "COPY"}}]
-def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
-  return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}]
-def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
-  ret = []
-  for i,e in enumerate(ev.ents):
-    st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]
-    ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
-    for dep in ev.deps[i]:
-      d = ev.ents[dep]
-      ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]
-      ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]
-  return ret
-def to_perfetto(profile:list[ProfileEvent]):
-  # Start json with devices.
-  prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)]
-  for ev in tqdm(profile, desc="preparing profile"):
-    if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
-    elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))
-  return json.dumps({"traceEvents": prof_json}).encode() if len(prof_json) > 0 else None
+
+def events_to_json(profile:list[ProfileEvent]):
+  # flatten range events and graph entries into (device, name, start, end, is_copy) tuples
+  for e in profile:
+    if isinstance(e, ProfileRangeEvent): yield (e.device, e.name, e.st, e.en, e.is_copy)
+    if isinstance(e, ProfileGraphEvent):
+      for ent in e.ents: yield (ent.device, ent.name, e.sigs[ent.st_id], e.sigs[ent.en_id], ent.is_copy)
+
+def get_profile(profile:list[ProfileEvent]):
+  # returns utf-8 encoded JSON: {"devEvents": {device: [{"name", "ts", "dur"}, ...]}, "st": min start, "et": max end}
+  # start by getting the time diffs
+  devs = {e.device:(e.comp_tdiff, e.copy_tdiff if e.copy_tdiff is not None else e.comp_tdiff) for e in profile if isinstance(e,ProfileDeviceEvent)}
+  # map events per device
+  dev_events:dict[str, list] = {}
+  min_ts:int|None = None
+  max_ts:int|None = None
+  for device, name, ts, en, is_copy in events_to_json(profile):
+    # is_copy indexes the (comp_tdiff, copy_tdiff) pair stored for the device
+    time_diff = devs[device][is_copy]
+    st = int(ts+time_diff)
+    et = st if en is None else int(en+time_diff)
+    dev_events.setdefault(device,[]).append({"name":name, "ts":st, "dur":et-st})
+    if min_ts is None or st < min_ts: min_ts = st
+    if max_ts is None or et > max_ts: max_ts = et
+  return json.dumps({"devEvents":dev_events, "st":min_ts, "et":max_ts}).encode("utf-8")
 
 # ** HTTP server
 
@@ -126,7 +124,7 @@ class Handler(BaseHTTPRequestHandler):
   def do_GET(self):
     ret, status_code, content_type = b"", 200, "text/html"
 
-    if (fn:={"/":"index", "/profiler":"perfetto"}.get((url:=urlparse(self.path)).path)):
+    if (fn:={"/":"index"}.get((url:=urlparse(self.path)).path)):
       with open(os.path.join(os.path.dirname(__file__), f"{fn}.html"), "rb") as f: ret = f.read()
     elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
       try:
@@ -151,7 +149,7 @@ class Handler(BaseHTTPRequestHandler):
         # pass if client closed connection
         except (BrokenPipeError, ConnectionResetError): return
       ret, content_type = json.dumps(ctxs).encode(), "application/json"
-    elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
+    elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/json"
     else: status_code = 404
 
     # send response
@@ -197,7 +195,7 @@ if __name__ == "__main__":
   # NOTE: this context is a tuple of list[keys] and list[values]
   ctxs = get_metadata(*contexts[:2]) if contexts is not None else []
 
-  perfetto_profile = to_perfetto(profile) if profile is not None else None
+  profile_ret = get_profile(profile) if profile is not None else None
 
   server = TCPServerWithReuse(('', PORT), Handler)
   reloader_thread = threading.Thread(target=reloader)