From 2407fecdae1bc6c59eca147d099173471f2abd52 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sat, 23 Aug 2025 23:50:21 +0300
Subject: [PATCH] viz bytepack format (#11792)

* viz bytepack format

Training a 1B llama yields ~20M profiler events. With JSON serialization, the browser tries to load 6GB into memory. This OOMs since each tab is limited to <3-4GB of memory. With a packed format, we only need ~600MB.

**Design decisions:**

- Timestamps are in microseconds relative to the start time. They're stored as u32, which can express up to ~1 hr of trace events.
- Strings (kernel names, metadata, etc.) are deduped.
- Buffer sizes are stored as u64 nbytes.

More optimization is possible:

- The string lookup is a JSON-dumped array; we could compress it.
- We can store less for the memory tracks by moving the layout to the client.

**Results**

| | Events | JSON | bytepack |
|----------------|---------|-------------|-------------|
| DP=8 llama 1B train (`command: [1]`) | 24M | 5.8GB | 640MB |
| examples/beautiful_mnist.py | 16K | 3.7MB | 745KB |
| examples/gpt2.py | 55K | 12.54MB | 1.40MB |

`[1]`: `VIZ=1 FAKEDATA=1 OFFLOAD_OPTIM=1 DP=8 BS=8 GRADIENT_ACC_STEPS=2 BLOCK_REORDER=0 LR=3e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=8192 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py`

* python reference decoder (a minimal header-decoding sketch also follows the test diff below)

* 27 bytes / event, 1hr hard limit
---
 test/unit/test_viz.py    | 52 ++++++++++++++++++++++++++++++++++++----
 tinygrad/viz/js/index.js | 45 ++++++++++++++++++++++++----
 tinygrad/viz/serve.py    | 39 +++++++++++++++++++++---------
 3 files changed, 107 insertions(+), 29 deletions(-)

diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py
index d0b0d632a6..84866f2819 100644
--- a/test/unit/test_viz.py
+++ b/test/unit/test_viz.py
@@ -1,4 +1,4 @@
-import unittest, decimal, json
+import unittest, decimal, json, struct
 from dataclasses import dataclass
 
 from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher
@@ -252,9 +252,41 @@ class TestVizIntegration(BaseTestViz):
 from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
 from tinygrad.viz.serve import get_profile
 
+class TinyUnpacker:
+  def __init__(self, buf): self.buf, self.offset = buf, 0
+  def __call__(self, fmt:str) -> tuple:
+    ret = struct.unpack_from(fmt, self.buf, self.offset)
+    self.offset += struct.calcsize(fmt)
+    return ret
+
+# 0 means None, otherwise it's an enum value
+def option(i:int) -> int|None: return None if i == 0 else i-1
+
 def load_profile(lst:list[ProfileEvent]) -> dict:
   ret = get_profile(lst)
-  return json.loads(ret)
+  u = TinyUnpacker(ret)
+  dur, global_peak, index_len, layout_len = u("<IQII")
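Before the index.js changes, a note on the container layout: the packed buffer starts with a fixed header (u32 duration in us, u64 global memory peak in bytes, u32 index length, u32 number of track layouts), followed by a UTF-8 JSON index holding the deduped `strings` table and the `dtypes` itemsize map, followed by the per-track blobs produced by `timeline_layout` and `mem_layout`. Below is a minimal, standalone sketch of reading that header; `read_header` is a hypothetical helper for illustration only, not part of the test above.

```python
import json, struct

HEADER_FMT = "<IQII"  # u32 dur (us), u64 global peak (bytes), u32 index length, u32 track count

def read_header(buf: bytes) -> tuple[dict, int]:
  # unpack the fixed-size header at the start of the bytepack buffer
  dur, global_peak, index_len, layout_len = struct.unpack_from(HEADER_FMT, buf, 0)
  off = struct.calcsize(HEADER_FMT)
  # the index is a JSON blob with the deduped string table and the dtype itemsize map
  index = json.loads(buf[off:off+index_len].decode("utf-8"))
  off += index_len  # off now points at the first per-track layout blob
  return {"dur": dur, "peak": global_peak, "layouts": layout_len,
          "strings": index["strings"], "dtypes": index["dtypes"]}, off
```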
diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js
--- a/tinygrad/viz/js/index.js
+++ b/tinygrad/viz/js/index.js
+  const u8 = () => { const ret = view.getUint8(offset); offset += 1; return ret; }
+  const u32 = () => { const ret = view.getUint32(offset, true); offset += 4; return ret; }
+  const u64 = () => { const ret = new Number(view.getBigUint64(offset, true)); offset += 8; return ret; }
+  const f32 = () => { const ret = view.getFloat32(offset, true); offset += 4; return ret; }
+  const optional = (i) => i === 0 ? null : i-1;
+  const dur = u32(), peak = u64(), indexLen = u32(), layoutsLen = u32();
+  const textDecoder = new TextDecoder("utf-8");
+  const { strings, dtypes } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen;
   // place devices on the y axis and set vertical positions
   const [tickSize, padding] = [10, 8];
   const deviceList = profiler.append("div").attr("id", "device-list").style("padding-top", tickSize+padding+"px");
@@ -164,17 +174,22 @@ async function renderProfiler() {
     const colorMap = new Map();
     data = {tracks:new Map(), axes:{}};
     const heightScale = d3.scaleLinear().domain([0, peak]).range([4,maxheight=100]);
-    for (const [k, v] of Object.entries(layout)) {
-      if (v.shapes.length === 0) continue;
+    for (let i=0; i
 v.timestamps[tsIdx]);
+        for (let j=0; j
 timestamps[u32()]);
+          const e = {y:Array.from({ length }, u64), arg:{dtype:strings[u32()], sz:u64()}};
           const nbytes = dtypes[e.arg.dtype]*e.arg.sz;
-          const arg = {tooltipText:`${e.arg.dtype} len:${formatUnit(e.arg.sz)}\n${formatUnit(nbytes, "B")}`};
-          shapes.push({ x, y0:e.y.map(yscale), y1:e.y.map(y => yscale(y+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, i) });
+          const arg = {tooltipText:`${e.arg.dtype} len:${formatUnit(e.arg.sz)}\n${formatUnit(nbytes, "B")}`};
+          shapes.push({ x, y0:e.y.map(yscale), y1:e.y.map(y => yscale(y+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, j) });
         }
-      data.tracks.set(k, { shapes, offsetY, height, peak:v.peak, scaleFactor:maxheight*4/height });
+      data.tracks.set(k, { shapes, offsetY, height, peak, scaleFactor:maxheight*4/height });
       div.style("height", height+padding+"px").style("cursor", "pointer").on("click", (e) => {
         const newFocus = e.currentTarget.id === focusedDevice ? null : e.currentTarget.id;
         let offset = 0;
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 6225cd94e8..ca6d8251b3 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io
+import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
 import subprocess, ctypes
 from contextlib import redirect_stdout
 from decimal import Decimal
@@ -106,6 +106,15 @@ def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None,
            "diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))}
     if not ctx.bottom_up: next_sink = new_sink
 
+# encoder helpers
+
+def enum_str(s, cache:dict[str, int]) -> int:
+  if (cret:=cache.get(s)) is not None: return cret
+  cache[s] = ret = len(cache)
+  return ret
+
+def option(s:int|None) -> int: return 0 if s is None else s+1
+
 # Profiler API
 
 device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
@@ -123,10 +132,11 @@ def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decim
     for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
 
 # timeline layout stacks events in a contiguous block. When a late starter finishes late, there is whitespace in the higher levels.
-def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int) -> dict:
-  shapes:list[dict] = []
+def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
+  shapes:list[bytes] = []
   levels:list[int] = []
   exec_points:dict[str, dict] = {}
+  category_enum:dict[str, int] = {}
   for st,et,dur,e in events:
     if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.key] = e.arg
     if dur == 0: continue
@@ -143,10 +153,12 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int)
     elif isinstance(e.name, TracingKey):
       name, cat = e.name.display_name, e.name.cat
       ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
-    shapes.append({"name":name, "ref":ref, "st":st-start_ts, "dur":dur, "depth":depth, "cat":cat, "info":info})
-  return {"shapes":shapes, "maxDepth":len(levels)}
+    shapes.append(struct.pack("
 
-def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtypes_map:dict[str, int]) -> dict:
+def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtypes_map:dict[str, int],
+               scache:dict[str, int]) -> bytes|None:
   step, peak, mem = 0, 0, 0
   shps:dict[int, dict] = {}
   temp:dict[int, dict] = {}
@@ -175,7 +187,9 @@ def mem_layout(events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_
       v["y"].append(v["y"][-1])
   timestamps.append(end_ts-start_ts)
   peaks.append(peak)
-  return {"shapes":list(shps.values()), "peak":peak, "timestamps":timestamps}
+  bufs = [struct.pack("
 
 def get_profile(profile:list[ProfileEvent]) -> bytes|None:
   # start by getting the time diffs
@@ -191,14 +205,17 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None:
     if end_ts is None or et > end_ts: end_ts = et
   if start_ts is None: return None
   # return layout of per device events
-  layout:dict[str, dict] = {}
+  layout:dict[str, bytes|None] = {}
+  scache:dict[str, int] = {}
   peaks:list[int] = []
   dtypes_map:dict[str, int] = {}
   for k,v in dev_events.items():
     v.sort(key=lambda e:e[0])
-    layout[k] = timeline_layout(v, start_ts)
-    layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtypes_map)
+    layout[k] = timeline_layout(v, start_ts, scache)
+    layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtypes_map, scache)
-  return json.dumps({"layout":layout, "dur":unwrap(end_ts)-start_ts, "peak":max(peaks, default=0), "dtypes":dtypes_map}).encode("utf-8")
+  ret = [b"".join([struct.pack("
 
 list[dict]:
   ret:list[dict] = []
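To make the string dedup and the optional-value encoding concrete, here is a small round-trip sketch built from the helpers the patch adds (`enum_str`/`option` on the encoder side in serve.py, `option` on the decoder side in test_viz.py). The `encode_option`/`decode_option` names and the assumption that the `strings` index is the `enum_str` cache keys in insertion order are illustrative, not taken from the patch; the final assert just spells out the ~1 hr limit of u32 microsecond timestamps mentioned in the commit message.

```python
import struct

def enum_str(s, cache: dict[str, int]) -> int:
  # first occurrence of a string gets the next id, repeats reuse it (same helper as serve.py)
  if (cret := cache.get(s)) is not None: return cret
  cache[s] = ret = len(cache)
  return ret

def encode_option(s: int | None) -> int: return 0 if s is None else s+1  # serve.py's option()
def decode_option(i: int) -> int | None: return None if i == 0 else i-1  # test_viz.py's option()

scache: dict[str, int] = {}
packed = struct.pack("<II", enum_str("r_example_kernel", scache), encode_option(None))
strings = list(scache)  # assumed: the deduped string table shipped in the JSON index
name_idx, ref = struct.unpack("<II", packed)
assert strings[name_idx] == "r_example_kernel" and decode_option(ref) is None

# u32 microsecond timestamps cap a trace at 2**32 us ~= 71.6 minutes, hence the 1hr hard limit
assert 2**32 / 1e6 / 60 < 72
```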