viz bytepack format (#11792)

* viz bytepack format

Training a 1B llama yields ~20M profiler events.

With JSON serialization, the browser tries to load ~6GB into memory. This OOMs since each tab is limited to roughly 3-4GB of memory. With the packed format, only ~600MB is needed.

**Design decisions:**

- Timestamps are microseconds relative to the trace start time. They're stored as u32, which can express up to ~71 minutes of trace events.
- Strings (kernel names, metadata, etc.) are deduplicated and referenced by index into a shared lookup table.
- Buffer sizes are stored as u64 byte counts (a minimal packing sketch follows this list).
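
A minimal sketch of the packing scheme, using the `enum_str`/`option` helpers from the diff below (the event values here are made up for illustration):

```python
import struct, json

# dedupe strings: each unique string gets a stable index into a shared lookup table
def enum_str(s:str, cache:dict[str, int]) -> int:
  if (cret:=cache.get(s)) is not None: return cret
  cache[s] = ret = len(cache)
  return ret
# optional values: 0 means None, otherwise the value is shifted up by one
def option(v:int|None) -> int: return 0 if v is None else v+1

scache:dict[str, int] = {}
# one timeline event: name index, optional ref, u32 start (us), f32 duration, depth, optional category, info index
event = struct.pack("<IIIfBBI", enum_str("E_2", scache), option(None), 0, 10.0, 0, option(None), enum_str("", scache))
assert len(event) == 22  # fixed 22 bytes per timeline event; the string index is amortized on top
# the string table is shipped once as a JSON blob in the header
index = json.dumps({"strings":list(scache), "dtypes":{}}).encode()
# u32 microsecond timestamps cap relative times at 2**32 us ~= 71.6 minutes
print(len(event), len(index), round(2**32/1e6/60, 1))
```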

Further optimizations are possible:

- The string lookup table is a JSON-dumped array; it could be compressed (a small illustration follows this list).
- Memory events could be stored more compactly by moving the layout computation to the client.
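
For the first point, one option (not part of this PR, purely an illustration) is to zlib-compress the JSON string index before appending it:

```python
import json, zlib

# hypothetical: compress the JSON-dumped string index; the kernel names below are made up
strings = [f"E_{i}_4n2" for i in range(10_000)]
index = json.dumps({"strings":strings, "dtypes":{}}).encode()
packed = zlib.compress(index, 9)
print(len(index), len(packed))  # repetitive kernel names compress well
```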

**Results**

| Workload | Events | JSON | bytepack |
|----------------|---------|-------------|-------------|
| DP=8 llama 1B train (`command: [1]`) | 24M | 5.8GB | 640MB |
| examples/beautiful_mnist.py | 16K | 3.7MB | 745KB |
| examples/gpt2.py | 55K | 12.54MB | 1.40MB |

`[1]`: `VIZ=1 FAKEDATA=1 OFFLOAD_OPTIM=1 DP=8 BS=8 GRADIENT_ACC_STEPS=2 BLOCK_REORDER=0 LR=3e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=8192 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py`

* python reference decoder

* 27 bytes / event, 1hr hard limit
Author: qazal
Date: 2025-08-23 23:50:21 +03:00
Committed by: GitHub
Parent: b12d1d866c
Commit: 2407fecdae
3 changed files with 107 additions and 29 deletions

View File

@@ -1,4 +1,4 @@
import unittest, decimal, json
import unittest, decimal, json, struct
from dataclasses import dataclass
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher
@@ -252,9 +252,41 @@ class TestVizIntegration(BaseTestViz):
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_profile
class TinyUnpacker:
def __init__(self, buf): self.buf, self.offset = buf, 0
def __call__(self, fmt:str) -> tuple:
ret = struct.unpack_from(fmt, self.buf, self.offset)
self.offset += struct.calcsize(fmt)
return ret
# 0 means None, otherwise it's an enum value
def option(i:int) -> int|None: return None if i == 0 else i-1
def load_profile(lst:list[ProfileEvent]) -> dict:
ret = get_profile(lst)
return json.loads(ret)
u = TinyUnpacker(ret)
dur, global_peak, index_len, layout_len = u("<IQII")
strings, dtypes = json.loads(ret[u.offset:u.offset+index_len]).values()
u.offset += index_len
layout:dict[str, dict] = {}
for _ in range(layout_len):
klen = u("<B")[0]
k = ret[u.offset:u.offset+klen].decode()
u.offset += klen
layout[k] = v = {"shapes":[]}
event_type, event_count = u("<BI")
if event_type == 0:
v["max_depth"] = u("<B")
for _ in range(event_count):
name, ref, st, dur, depth, cat, _ = u("<IIIfBBI")
v["shapes"].append({"name":strings[name], "ref":option(ref), "st":st, "dur":dur, "depth":depth, "cat":option(cat)})
else:
v["peak"] = u("<Q")[0]
v["timestamps"] = list(u(f"<{u('I')[0]}I"))
for _ in range(event_count):
i = u("<I")[0]
v["shapes"].append({"x":list(u(f"<{i}I")), "y":list(u(f"<{i}Q")), "arg": {"dtype":strings[u("<I")[0]], "sz":u("<Q")[0]}})
return {"dur":dur, "peak":global_peak, "layout":layout}
class TestVizProfiler(unittest.TestCase):
def test_perfetto_node(self):
@@ -269,6 +301,7 @@ class TestVizProfiler(unittest.TestCase):
self.assertEqual(event['name'], 'E_2')
self.assertEqual(event['st'], 0)
self.assertEqual(event['dur'], 10)
assert event['ref'] is None
def test_perfetto_copy_node(self):
prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
@@ -298,8 +331,8 @@ class TestVizProfiler(unittest.TestCase):
tracks = list(j['layout'])
self.assertEqual(tracks[0], 'NV Graph')
self.assertEqual(tracks[2], 'NV')
self.assertEqual(tracks[4], 'NV:1')
self.assertEqual(tracks[1], 'NV')
self.assertEqual(tracks[2], 'NV:1')
nv_events = j['layout']['NV']['shapes']
self.assertEqual(nv_events[0]['name'], 'E_25_4n2')
@@ -321,7 +354,16 @@ class TestVizProfiler(unittest.TestCase):
n_events = 1_000
prof = [ProfileRangeEvent("CPU", name="k_test", st=decimal.Decimal(ts:=i*step), en=decimal.Decimal(ts)+step) for i in range(n_events)]
sz = len(get_profile(prof))
self.assertLessEqual(sz/n_events, 100)
self.assertLessEqual(sz/n_events, 27)
# can pack up to 1hr 11 min of trace events
def test_trace_duration(self):
dur_mins = 72
n_events = 1_000
step = decimal.Decimal(dur_mins*60*1e6//n_events)
prof = [ProfileRangeEvent("CPU", name="k_test", st=decimal.Decimal(ts:=i*step), en=decimal.Decimal(ts)+step) for i in range(n_events)]
with self.assertRaises(struct.error):
get_profile(prof)
def _alloc(b:int):
a = Tensor.empty(b, device="NULL", dtype=dtypes.char)

View File

@@ -151,7 +151,17 @@ async function renderProfiler() {
// layout once!
if (data != null) return;
const profiler = d3.select(".profiler").html("");
const { layout, dur, peak, dtypes } = await (await fetch("/get_profile")).json();
const buf = await (await fetch("/get_profile")).arrayBuffer();
const view = new DataView(buf);
let offset = 0;
const u8 = () => { const ret = view.getUint8(offset); offset += 1; return ret; }
const u32 = () => { const ret = view.getUint32(offset, true); offset += 4; return ret; }
const u64 = () => { const ret = new Number(view.getBigUint64(offset, true)); offset += 8; return ret; }
const f32 = () => { const ret = view.getFloat32(offset, true); offset += 4; return ret; }
const optional = (i) => i === 0 ? null : i-1;
const dur = u32(), peak = u64(), indexLen = u32(), layoutsLen = u32();
const textDecoder = new TextDecoder("utf-8");
const { strings, dtypes } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen;
// place devices on the y axis and set vertical positions
const [tickSize, padding] = [10, 8];
const deviceList = profiler.append("div").attr("id", "device-list").style("padding-top", tickSize+padding+"px");
@@ -164,17 +174,22 @@ async function renderProfiler() {
const colorMap = new Map();
data = {tracks:new Map(), axes:{}};
const heightScale = d3.scaleLinear().domain([0, peak]).range([4,maxheight=100]);
for (const [k, v] of Object.entries(layout)) {
if (v.shapes.length === 0) continue;
for (let i=0; i<layoutsLen; i++) {
const nameLen = view.getUint8(offset, true); offset += 1;
const k = textDecoder.decode(new Uint8Array(buf, offset, nameLen)); offset += nameLen;
const div = deviceList.append("div").attr("id", k).text(k).style("padding", padding+"px");
const { y:baseY, height:baseHeight } = rect(div.node());
const offsetY = baseY-canvasTop+padding/2;
if (v.shapes[0].dur != null) {
const EventTypes = {TIMELINE:0, MEMORY:1};
const eventType = u8(), eventsLen = u32();
if (eventType === EventTypes.TIMELINE) {
const maxDepth = u8();
const levelHeight = baseHeight-padding;
const shapes = [];
data.tracks.set(k, { shapes, offsetY });
let colorKey, ref;
for (const e of v.shapes) {
for (let j=0; j<eventsLen; j++) {
const e = {name:strings[u32()], ref:optional(u32()), st:u32(), dur:f32(), depth:u8(), cat:optional(u8()), info:strings[u32()] || null};
if (e.depth === 0) colorKey = e.cat ?? e.name;
if (!colorMap.has(colorKey)) colorMap.set(colorKey, cycleColors(colorScheme[k] ?? colorScheme.DEFAULT, colorMap.size));
const fillColor = d3.color(colorMap.get(colorKey)).brighter(e.depth).toString();
@@ -189,18 +204,22 @@ async function renderProfiler() {
// offset y by depth
shapes.push({x:e.st, y:levelHeight*e.depth, width:e.dur, height:levelHeight, arg, label, fillColor });
}
div.style("height", levelHeight*v.maxDepth+padding+"px").style("pointerEvents", "none");
div.style("height", levelHeight*maxDepth+padding+"px").style("pointerEvents", "none");
} else {
const height = heightScale(v.peak);
const yscale = d3.scaleLinear().domain([0, v.peak]).range([height, 0]);
const peak = u64();
const height = heightScale(peak);
const yscale = d3.scaleLinear().domain([0, peak]).range([height, 0]);
const timestamps = Array.from({length:u32()}, u32);
const shapes = [];
for (const [i,e] of v.shapes.entries()) {
const x = e.x.map(tsIdx => v.timestamps[tsIdx]);
for (let j=0; j<eventsLen; j++) {
const length = u32();
const x = Array.from({ length }, () => timestamps[u32()]);
const e = {y:Array.from({ length }, u64), arg:{dtype:strings[u32()], sz:u64()}};
const nbytes = dtypes[e.arg.dtype]*e.arg.sz;
const arg = {tooltipText:`${e.arg.dtype} len:${formatUnit(e.arg.sz)}\n${formatUnit(nbytes, "B")}`};
shapes.push({ x, y0:e.y.map(yscale), y1:e.y.map(y => yscale(y+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, i) });
const arg = {tooltipText:`${e.arg.dtype} len:${formatUnit(e.arg.sz)}\n${formatUnit(e.arg.nbytes, "B")}`};
shapes.push({ x, y0:e.y.map(yscale), y1:e.y.map(y => yscale(y+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, j) });
}
data.tracks.set(k, { shapes, offsetY, height, peak:v.peak, scaleFactor:maxheight*4/height });
data.tracks.set(k, { shapes, offsetY, height, peak, scaleFactor:maxheight*4/height });
div.style("height", height+padding+"px").style("cursor", "pointer").on("click", (e) => {
const newFocus = e.currentTarget.id === focusedDevice ? null : e.currentTarget.id;
let offset = 0;

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
import subprocess, ctypes
from contextlib import redirect_stdout
from decimal import Decimal
@@ -106,6 +106,15 @@ def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None,
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))}
if not ctx.bottom_up: next_sink = new_sink
# encoder helpers
def enum_str(s, cache:dict[str, int]) -> int:
if (cret:=cache.get(s)) is not None: return cret
cache[s] = ret = len(cache)
return ret
def option(s:int|None) -> int: return 0 if s is None else s+1
# Profiler API
device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
@@ -123,10 +132,11 @@ def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decim
for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
# timeline layout stacks events in a contiguous block. When a late starter finishes late, there is whitespace in the higher levels.
def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int) -> dict:
shapes:list[dict] = []
def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
shapes:list[bytes] = []
levels:list[int] = []
exec_points:dict[str, dict] = {}
category_enum:dict[str, int] = {}
for st,et,dur,e in events:
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.key] = e.arg
if dur == 0: continue
@@ -143,10 +153,12 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int)
elif isinstance(e.name, TracingKey):
name, cat = e.name.display_name, e.name.cat
ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
shapes.append({"name":name, "ref":ref, "st":st-start_ts, "dur":dur, "depth":depth, "cat":cat, "info":info})
return {"shapes":shapes, "maxDepth":len(levels)}
shapes.append(struct.pack("<IIIfBBI", enum_str(name,scache), option(ref), st-start_ts, dur, depth,
option(None if cat is None else enum_str(cat, category_enum)), enum_str(info or "",scache)))
return struct.pack("<BIB", 0, len(shapes), len(levels))+b"".join(shapes) if shapes else None
def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtypes_map:dict[str, int]) -> dict:
def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtypes_map:dict[str, int],
scache:dict[str, int]) -> bytes|None:
step, peak, mem = 0, 0, 0
shps:dict[int, dict] = {}
temp:dict[int, dict] = {}
@@ -175,7 +187,9 @@ def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_
v["y"].append(v["y"][-1])
timestamps.append(end_ts-start_ts)
peaks.append(peak)
return {"shapes":list(shps.values()), "peak":peak, "timestamps":timestamps}
bufs = [struct.pack("<I"+str(i:=len(v['x']))+f"I{i}QIQ", i, *v["x"], *v["y"], enum_str(v["arg"]["dtype"], scache),
v["arg"]["sz"]) for v in shps.values()]
return struct.pack("<BIQI", 1, len(shps), peak, len(timestamps))+struct.pack(f"<{len(timestamps)}I", *timestamps)+b"".join(bufs) if bufs else None
def get_profile(profile:list[ProfileEvent]) -> bytes|None:
# start by getting the time diffs
@@ -191,14 +205,17 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None:
if end_ts is None or et > end_ts: end_ts = et
if start_ts is None: return None
# return layout of per device events
layout:dict[str, dict] = {}
layout:dict[str, bytes|None] = {}
scache:dict[str, int] = {}
peaks:list[int] = []
dtypes_map:dict[str, int] = {}
for k,v in dev_events.items():
v.sort(key=lambda e:e[0])
layout[k] = timeline_layout(v, start_ts)
layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtypes_map)
return json.dumps({"layout":layout, "dur":unwrap(end_ts)-start_ts, "peak":max(peaks, default=0), "dtypes":dtypes_map}).encode("utf-8")
layout[k] = timeline_layout(v, start_ts, scache)
layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtypes_map, scache)
ret = [b"".join([struct.pack("<B", len(k)), k.encode(), v]) for k,v in layout.items() if v is not None]
index = json.dumps({"strings":list(scache), "dtypes":dtypes_map}).encode()
return struct.pack("<IQII", unwrap(end_ts)-start_ts, max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
def get_runtime_stats(key) -> list[dict]:
ret:list[dict] = []