viz bytepack format (#11792)
* viz bytepack format

Training a 1B llama yields ~20M profiler events. With JSON serialization, the browser tries to load 6GB into memory. This OOMs since each tab is limited to <3-4GB of memory usage. Using a packed format, we only need ~600MB.

**Design decisions:**

- Timestamps are in microseconds relative to start time. They're stored in u32, which can express up to ~1 hr of trace events.
- Strings (kernel names, metadata, etc.) are deduped.
- Buffer sizes are in u64 nbytes.

More optimization is possible:

- The string lookup is a JSON-dumped array; we can compress this.
- We can store less for memory by moving the layout to the client.

**Results**

| | Events | JSON | bytepack |
|----------------|---------|-------------|-------------|
| DP=8 llama 1B train (`command: [1]`) | 24M | 5.8GB | 640MB |
| examples/beautiful_mnist.py | 16K | 3.7MB | 745KB |
| examples/gpt2.py | 55K | 12.54MB | 1.40MB |

`[1]`: `VIZ=1 FAKEDATA=1 OFFLOAD_OPTIM=1 DP=8 BS=8 GRADIENT_ACC_STEPS=2 BLOCK_REORDER=0 LR=3e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=8192 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py`

* python reference decoder

* 27 bytes / event, 1hr hard limit
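For orientation, here is a minimal sketch of the packed layout implied by the struct formats in the encoder and the Python reference decoder below; the constant names are illustrative, not from the code:

```python
import struct

# file header: dur (u32, us), global peak (u64, bytes), string-index length (u32), track count (u32)
HEADER = "<IQII"
# one timeline event: name idx (u32), ref+1 (u32), start (u32, us), dur (f32), depth (u8), cat+1 (u8), info idx (u32)
TIMELINE_EVENT = "<IIIfBBI"

print(struct.calcsize(HEADER))          # 20 bytes, once per trace
print(struct.calcsize(TIMELINE_EVENT))  # 22 bytes per event; per-track headers plus the deduped
                                        # string index bring the amortized cost to <=27 bytes/event
# u32 microsecond timestamps cap a single trace at 2**32 us:
print(2**32 / 1e6 / 60)                 # ~71.6 minutes, hence the "1hr hard limit"
```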
@@ -1,4 +1,4 @@
-import unittest, decimal, json
+import unittest, decimal, json, struct
 from dataclasses import dataclass
 
 from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher
@@ -252,9 +252,41 @@ class TestVizIntegration(BaseTestViz):
 from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
 from tinygrad.viz.serve import get_profile
 
+class TinyUnpacker:
+  def __init__(self, buf): self.buf, self.offset = buf, 0
+  def __call__(self, fmt:str) -> tuple:
+    ret = struct.unpack_from(fmt, self.buf, self.offset)
+    self.offset += struct.calcsize(fmt)
+    return ret
+
+# 0 means None, otherwise it's an enum value
+def option(i:int) -> int|None: return None if i == 0 else i-1
+
 def load_profile(lst:list[ProfileEvent]) -> dict:
   ret = get_profile(lst)
-  return json.loads(ret)
+  u = TinyUnpacker(ret)
+  dur, global_peak, index_len, layout_len = u("<IQII")
+  strings, dtypes = json.loads(ret[u.offset:u.offset+index_len]).values()
+  u.offset += index_len
+  layout:dict[str, dict] = {}
+  for _ in range(layout_len):
+    klen = u("<B")[0]
+    k = ret[u.offset:u.offset+klen].decode()
+    u.offset += klen
+    layout[k] = v = {"shapes":[]}
+    event_type, event_count = u("<BI")
+    if event_type == 0:
+      v["max_depth"] = u("<B")
+      for _ in range(event_count):
+        name, ref, st, dur, depth, cat, _ = u("<IIIfBBI")
+        v["shapes"].append({"name":strings[name], "ref":option(ref), "st":st, "dur":dur, "depth":depth, "cat":option(cat)})
+    else:
+      v["peak"] = u("<Q")[0]
+      v["timestamps"] = list(u(f"<{u('I')[0]}I"))
+      for _ in range(event_count):
+        i = u("<I")[0]
+        v["shapes"].append({"x":list(u(f"<{i}I")), "y":list(u(f"<{i}Q")), "arg": {"dtype":strings[u("<I")[0]], "sz":u("<Q")[0]}})
+  return {"dur":dur, "peak":global_peak, "layout":layout}
 
 class TestVizProfiler(unittest.TestCase):
   def test_perfetto_node(self):
@@ -269,6 +301,7 @@ class TestVizProfiler(unittest.TestCase):
     self.assertEqual(event['name'], 'E_2')
     self.assertEqual(event['st'], 0)
     self.assertEqual(event['dur'], 10)
+    assert event['ref'] is None
 
   def test_perfetto_copy_node(self):
     prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
@@ -298,8 +331,8 @@ class TestVizProfiler(unittest.TestCase):
 
     tracks = list(j['layout'])
     self.assertEqual(tracks[0], 'NV Graph')
-    self.assertEqual(tracks[2], 'NV')
-    self.assertEqual(tracks[4], 'NV:1')
+    self.assertEqual(tracks[1], 'NV')
+    self.assertEqual(tracks[2], 'NV:1')
 
     nv_events = j['layout']['NV']['shapes']
     self.assertEqual(nv_events[0]['name'], 'E_25_4n2')
@@ -321,7 +354,16 @@ class TestVizProfiler(unittest.TestCase):
     n_events = 1_000
     prof = [ProfileRangeEvent("CPU", name="k_test", st=decimal.Decimal(ts:=i*step), en=decimal.Decimal(ts)+step) for i in range(n_events)]
     sz = len(get_profile(prof))
-    self.assertLessEqual(sz/n_events, 100)
+    self.assertLessEqual(sz/n_events, 27)
 
+  # can pack up to 1hr 11 min of trace events
+  def test_trace_duration(self):
+    dur_mins = 72
+    n_events = 1_000
+    step = decimal.Decimal(dur_mins*60*1e6//n_events)
+    prof = [ProfileRangeEvent("CPU", name="k_test", st=decimal.Decimal(ts:=i*step), en=decimal.Decimal(ts)+step) for i in range(n_events)]
+    with self.assertRaises(struct.error):
+      get_profile(prof)
+
 def _alloc(b:int):
   a = Tensor.empty(b, device="NULL", dtype=dtypes.char)
@@ -151,7 +151,17 @@ async function renderProfiler() {
   // layout once!
   if (data != null) return;
   const profiler = d3.select(".profiler").html("");
-  const { layout, dur, peak, dtypes } = await (await fetch("/get_profile")).json();
+  const buf = await (await fetch("/get_profile")).arrayBuffer();
+  const view = new DataView(buf);
+  let offset = 0;
+  const u8 = () => { const ret = view.getUint8(offset); offset += 1; return ret; }
+  const u32 = () => { const ret = view.getUint32(offset, true); offset += 4; return ret; }
+  const u64 = () => { const ret = new Number(view.getBigUint64(offset, true)); offset += 8; return ret; }
+  const f32 = () => { const ret = view.getFloat32(offset, true); offset += 4; return ret; }
+  const optional = (i) => i === 0 ? null : i-1;
+  const dur = u32(), peak = u64(), indexLen = u32(), layoutsLen = u32();
+  const textDecoder = new TextDecoder("utf-8");
+  const { strings, dtypes } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen;
   // place devices on the y axis and set vertical positions
   const [tickSize, padding] = [10, 8];
   const deviceList = profiler.append("div").attr("id", "device-list").style("padding-top", tickSize+padding+"px");
@@ -164,17 +174,22 @@ async function renderProfiler() {
   const colorMap = new Map();
   data = {tracks:new Map(), axes:{}};
   const heightScale = d3.scaleLinear().domain([0, peak]).range([4,maxheight=100]);
-  for (const [k, v] of Object.entries(layout)) {
-    if (v.shapes.length === 0) continue;
+  for (let i=0; i<layoutsLen; i++) {
+    const nameLen = view.getUint8(offset, true); offset += 1;
+    const k = textDecoder.decode(new Uint8Array(buf, offset, nameLen)); offset += nameLen;
     const div = deviceList.append("div").attr("id", k).text(k).style("padding", padding+"px");
     const { y:baseY, height:baseHeight } = rect(div.node());
     const offsetY = baseY-canvasTop+padding/2;
-    if (v.shapes[0].dur != null) {
+    const EventTypes = {TIMELINE:0, MEMORY:1};
+    const eventType = u8(), eventsLen = u32();
+    if (eventType === EventTypes.TIMELINE) {
+      const maxDepth = u8();
       const levelHeight = baseHeight-padding;
       const shapes = [];
       data.tracks.set(k, { shapes, offsetY });
       let colorKey, ref;
-      for (const e of v.shapes) {
+      for (let j=0; j<eventsLen; j++) {
+        const e = {name:strings[u32()], ref:optional(u32()), st:u32(), dur:f32(), depth:u8(), cat:optional(u8()), info:strings[u32()] || null};
         if (e.depth === 0) colorKey = e.cat ?? e.name;
         if (!colorMap.has(colorKey)) colorMap.set(colorKey, cycleColors(colorScheme[k] ?? colorScheme.DEFAULT, colorMap.size));
         const fillColor = d3.color(colorMap.get(colorKey)).brighter(e.depth).toString();
@@ -189,18 +204,22 @@ async function renderProfiler() {
         // offset y by depth
         shapes.push({x:e.st, y:levelHeight*e.depth, width:e.dur, height:levelHeight, arg, label, fillColor });
       }
-      div.style("height", levelHeight*v.maxDepth+padding+"px").style("pointerEvents", "none");
+      div.style("height", levelHeight*maxDepth+padding+"px").style("pointerEvents", "none");
     } else {
-      const height = heightScale(v.peak);
-      const yscale = d3.scaleLinear().domain([0, v.peak]).range([height, 0]);
+      const peak = u64();
+      const height = heightScale(peak);
+      const yscale = d3.scaleLinear().domain([0, peak]).range([height, 0]);
+      const timestamps = Array.from({length:u32()}, u32);
       const shapes = [];
-      for (const [i,e] of v.shapes.entries()) {
-        const x = e.x.map(tsIdx => v.timestamps[tsIdx]);
+      for (let j=0; j<eventsLen; j++) {
+        const length = u32();
+        const x = Array.from({ length }, () => timestamps[u32()]);
+        const e = {y:Array.from({ length }, u64), arg:{dtype:strings[u32()], sz:u64()}};
         const nbytes = dtypes[e.arg.dtype]*e.arg.sz;
-        const arg = {tooltipText:`${e.arg.dtype} len:${formatUnit(e.arg.sz)}\n${formatUnit(nbytes, "B")}`};
-        shapes.push({ x, y0:e.y.map(yscale), y1:e.y.map(y => yscale(y+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, i) });
+        const arg = {tooltipText:`${e.arg.dtype} len:${formatUnit(e.arg.sz)}\n${formatUnit(e.arg.nbytes, "B")}`};
+        shapes.push({ x, y0:e.y.map(yscale), y1:e.y.map(y => yscale(y+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, j) });
       }
-      data.tracks.set(k, { shapes, offsetY, height, peak:v.peak, scaleFactor:maxheight*4/height });
+      data.tracks.set(k, { shapes, offsetY, height, peak, scaleFactor:maxheight*4/height });
       div.style("height", height+padding+"px").style("cursor", "pointer").on("click", (e) => {
         const newFocus = e.currentTarget.id === focusedDevice ? null : e.currentTarget.id;
         let offset = 0;
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io
+import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
 import subprocess, ctypes
 from contextlib import redirect_stdout
 from decimal import Decimal
@@ -106,6 +106,15 @@ def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None,
           "diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))}
     if not ctx.bottom_up: next_sink = new_sink
 
+# encoder helpers
+
+def enum_str(s, cache:dict[str, int]) -> int:
+  if (cret:=cache.get(s)) is not None: return cret
+  cache[s] = ret = len(cache)
+  return ret
+
+def option(s:int|None) -> int: return 0 if s is None else s+1
+
 # Profiler API
 
 device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
@@ -123,10 +132,11 @@ def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decim
       for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
 
 # timeline layout stacks events in a contiguous block. When a late starter finishes late, there is whitespace in the higher levels.
-def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int) -> dict:
-  shapes:list[dict] = []
+def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
+  shapes:list[bytes] = []
   levels:list[int] = []
   exec_points:dict[str, dict] = {}
+  category_enum:dict[str, int] = {}
   for st,et,dur,e in events:
     if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.key] = e.arg
     if dur == 0: continue
@@ -143,10 +153,12 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int)
     elif isinstance(e.name, TracingKey):
       name, cat = e.name.display_name, e.name.cat
       ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
-    shapes.append({"name":name, "ref":ref, "st":st-start_ts, "dur":dur, "depth":depth, "cat":cat, "info":info})
-  return {"shapes":shapes, "maxDepth":len(levels)}
+    shapes.append(struct.pack("<IIIfBBI", enum_str(name,scache), option(ref), st-start_ts, dur, depth,
+                              option(None if cat is None else enum_str(cat, category_enum)), enum_str(info or "",scache)))
+  return struct.pack("<BIB", 0, len(shapes), len(levels))+b"".join(shapes) if shapes else None
 
-def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtypes_map:dict[str, int]) -> dict:
+def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtypes_map:dict[str, int],
+               scache:dict[str, int]) -> bytes|None:
   step, peak, mem = 0, 0, 0
   shps:dict[int, dict] = {}
   temp:dict[int, dict] = {}
@@ -175,7 +187,9 @@ def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_
     v["y"].append(v["y"][-1])
   timestamps.append(end_ts-start_ts)
   peaks.append(peak)
-  return {"shapes":list(shps.values()), "peak":peak, "timestamps":timestamps}
+  bufs = [struct.pack("<I"+str(i:=len(v['x']))+f"I{i}QIQ", i, *v["x"], *v["y"], enum_str(v["arg"]["dtype"], scache),
+          v["arg"]["sz"]) for v in shps.values()]
+  return struct.pack("<BIQI", 1, len(shps), peak, len(timestamps))+struct.pack(f"<{len(timestamps)}I", *timestamps)+b"".join(bufs) if bufs else None
 
 def get_profile(profile:list[ProfileEvent]) -> bytes|None:
   # start by getting the time diffs
@@ -191,14 +205,17 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None:
     if end_ts is None or et > end_ts: end_ts = et
   if start_ts is None: return None
   # return layout of per device events
-  layout:dict[str, dict] = {}
+  layout:dict[str, bytes|None] = {}
+  scache:dict[str, int] = {}
   peaks:list[int] = []
   dtypes_map:dict[str, int] = {}
   for k,v in dev_events.items():
     v.sort(key=lambda e:e[0])
-    layout[k] = timeline_layout(v, start_ts)
-    layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtypes_map)
-  return json.dumps({"layout":layout, "dur":unwrap(end_ts)-start_ts, "peak":max(peaks, default=0), "dtypes":dtypes_map}).encode("utf-8")
+    layout[k] = timeline_layout(v, start_ts, scache)
+    layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtypes_map, scache)
+  ret = [b"".join([struct.pack("<B", len(k)), k.encode(), v]) for k,v in layout.items() if v is not None]
+  index = json.dumps({"strings":list(scache), "dtypes":dtypes_map}).encode()
+  return struct.pack("<IQII", unwrap(end_ts)-start_ts, max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
 
 def get_runtime_stats(key) -> list[dict]:
   ret:list[dict] = []