mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
viz: add TINY device (#11095)
* viz: add TINY device * replace Any with a proper type * reorder * diff * rename * space * from diff * multiple keys
This commit is contained in:
@@ -7,6 +7,7 @@ from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put,
|
||||
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.uop.ops import TracingKey
|
||||
|
||||
# **************** Device ****************
|
||||
|
||||
@@ -56,8 +57,8 @@ class ProfileEvent: pass
|
||||
class ProfileDeviceEvent(ProfileEvent):
|
||||
device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702
|
||||
@dataclass
|
||||
class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfilePointEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; ref:int; arg:dict=field(default_factory=dict) # noqa: E702
|
||||
@@ -71,17 +72,13 @@ class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:boo
|
||||
@dataclass(frozen=True)
|
||||
class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[list[int]]; sigs:list[decimal.Decimal] # noqa: E702
|
||||
|
||||
@dataclass
|
||||
class ProfileResult: st:Optional[int]=None; en:Optional[int]=None # noqa: E702
|
||||
|
||||
@contextlib.contextmanager
|
||||
def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileResult, None, None]:
|
||||
res = ProfileResult(st:=time.perf_counter_ns())
|
||||
def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]:
|
||||
res = ProfileRangeEvent(device, name, decimal.Decimal(time.perf_counter_ns()) / 1000, is_copy=is_copy)
|
||||
try: yield res
|
||||
finally:
|
||||
res.en = en = time.perf_counter_ns()
|
||||
if PROFILE and display:
|
||||
Compiled.profile_events += [ProfileRangeEvent(device, name, decimal.Decimal(st) / 1000, decimal.Decimal(en) / 1000, is_copy=is_copy)]
|
||||
res.en = decimal.Decimal(time.perf_counter_ns()) / 1000
|
||||
if PROFILE and display: Compiled.profile_events.append(res)
|
||||
|
||||
# **************** Buffer + Allocators ****************
|
||||
|
||||
|
||||
@@ -767,6 +767,13 @@ class TrackedGraphRewrite:
|
||||
name: str|None # optional name of the rewrite
|
||||
depth: int # depth if it's a subrewrite
|
||||
bottom_up: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TracingKey:
|
||||
display_name:str # display name of this trace event
|
||||
keys:tuple[str, ...]=() # optional keys to search for related traces
|
||||
cat:str|None=None # optional category to color this by
|
||||
|
||||
tracked_keys:list[Any] = []
|
||||
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
||||
_name_cnt:dict[str, itertools.count] = {}
|
||||
@@ -784,10 +791,15 @@ def track_rewrites(name:Callable|bool=True):
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
tracked_keys.append((fn:=func.__name__)+f" n{next(_name_cnt.setdefault(fn, itertools.count(1)))}")
|
||||
tracked_ctxs.append([])
|
||||
ret = func(*args, **kwargs)
|
||||
# late import!
|
||||
from tinygrad.device import cpu_profile
|
||||
with cpu_profile(func.__name__, "TINY") as e:
|
||||
ret = func(*args, **kwargs)
|
||||
if TRACK_MATCH_STATS >= 2 and callable(name):
|
||||
name_ret = name(*args, **kwargs, ret=ret)
|
||||
tracked_keys[-1] = tracked_keys[-1].replace(fn, name_ret) if isinstance(name_ret, str) else name_ret
|
||||
tracked_keys[-1] = key = tracked_keys[-1].replace(fn, name_ret) if isinstance(name_ret, str) else name_ret
|
||||
if isinstance(key, str): e.name = TracingKey(key, (key,), func.__name__)
|
||||
else: e.name = TracingKey(f"{func.__name__} for {name_ret.name}", (name_ret.name,), func.__name__)
|
||||
if getenv("CAPTURE_PROCESS_REPLAY"):
|
||||
# find the unittest frame we're capturing in
|
||||
frm = sys._getframe(1)
|
||||
|
||||
@@ -108,7 +108,8 @@ function formatTime(ts, dur=ts) {
|
||||
}
|
||||
const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit;
|
||||
|
||||
const colors = ["#1D1F2A", "#2A2D3D", "#373B4F", "#444862", "#12131A", "#2F3244", "#3B3F54", "#4A4E65", "#181A23", "#232532", "#313548", "#404459"];
|
||||
const devColors = {"TINY":["#1B5745", "#1D2E62"],
|
||||
"DEFAULT":["#1D1F2A", "#2A2D3D", "#373B4F", "#444862", "#12131A", "#2F3244", "#3B3F54", "#4A4E65", "#181A23", "#232532", "#313548", "#404459"],}
|
||||
const bufColors = ["#3A57B7","#5066C1","#6277CD","#7488D8","#8A9BE3","#A3B4F2"];
|
||||
|
||||
var profileRet, focusedDevice, canvasZoom, zoomLevel = d3.zoomIdentity;
|
||||
@@ -141,12 +142,14 @@ async function renderProfiler() {
|
||||
const levelHeight = baseHeight-padding;
|
||||
const offsetY = baseY-canvasTop+padding/2;
|
||||
for (const [i,e] of timeline.shapes.entries()) {
|
||||
if (!nameMap.has(e.name)) {
|
||||
const label = parseColors(e.name).map(({ color, st }) => ({ color, st, width:ctx.measureText(st).width }));
|
||||
nameMap.set(e.name, { fillColor:colors[i%colors.length], label });
|
||||
}
|
||||
// offset y by depth
|
||||
data.shapes.push({ x:e.st-st, dur:e.dur, name:e.name, height:levelHeight, y:offsetY+levelHeight*e.depth, ref:e.ref, ...nameMap.get(e.name) });
|
||||
const label = parseColors(e.name).map(({ color, st }) => ({ color, st, width:ctx.measureText(st).width }));
|
||||
const colorKey = e.cat ?? e.name;
|
||||
if (!nameMap.has(colorKey)) {
|
||||
const colors = devColors[k] ?? devColors.DEFAULT;
|
||||
nameMap.set(colorKey, { fillColor:colors[i%colors.length] });
|
||||
}
|
||||
// offset y by depth
|
||||
data.shapes.push({ x:e.st-st, dur:e.dur, name:e.name, height:levelHeight, y:offsetY+levelHeight*e.depth, ref:e.ref, label, ...nameMap.get(colorKey) });
|
||||
}
|
||||
// position shapes on the canvas and scale to fit fixed area
|
||||
const startY = offsetY+(levelHeight*timeline.maxDepth)+padding/2;
|
||||
|
||||
@@ -4,7 +4,7 @@ from http.server import BaseHTTPRequestHandler
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Any, TypedDict, Generator
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, TracingKey, UOp, Ops, printable, GroupOp, srender, sint
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry, ProfilePointEvent
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -114,9 +114,12 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
|
||||
depth = next((i for i,level_et in enumerate(levels) if st>=level_et), len(levels))
|
||||
if depth < len(levels): levels[depth] = et
|
||||
else: levels.append(et)
|
||||
name = e.name
|
||||
name, cat = e.name, None
|
||||
if (ref:=ref_map.get(name)) is not None: name = ctxs[ref]["name"]
|
||||
shapes.append({"name":name, "ref":ref, "st":st, "dur":dur, "depth":depth})
|
||||
elif isinstance(e.name, TracingKey):
|
||||
name, cat = e.name.display_name, e.name.cat
|
||||
ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
|
||||
shapes.append({"name":name, "ref":ref, "st":st, "dur":dur, "depth":depth, "cat":cat})
|
||||
return {"shapes":shapes, "maxDepth":len(levels)}
|
||||
|
||||
def mem_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user