diff --git a/tinygrad/device.py b/tinygrad/device.py index 4b26bae3fb..721da7e50e 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -7,6 +7,7 @@ from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.renderer import Renderer +from tinygrad.uop.ops import TracingKey # **************** Device **************** @@ -56,8 +57,8 @@ class ProfileEvent: pass class ProfileDeviceEvent(ProfileEvent): device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702 -@dataclass(frozen=True) -class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702 +@dataclass +class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702 @dataclass(frozen=True) class ProfilePointEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; ref:int; arg:dict=field(default_factory=dict) # noqa: E702 @@ -71,17 +72,13 @@ class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:boo @dataclass(frozen=True) class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[list[int]]; sigs:list[decimal.Decimal] # noqa: E702 -@dataclass -class ProfileResult: st:Optional[int]=None; en:Optional[int]=None # noqa: E702 - @contextlib.contextmanager -def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileResult, None, None]: - res = ProfileResult(st:=time.perf_counter_ns()) +def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]: + res = ProfileRangeEvent(device, name, decimal.Decimal(time.perf_counter_ns()) / 1000, is_copy=is_copy) try: yield res finally: - res.en = en = time.perf_counter_ns() - if PROFILE and display: - Compiled.profile_events += [ProfileRangeEvent(device, name, decimal.Decimal(st) / 1000, decimal.Decimal(en) / 1000, is_copy=is_copy)] + res.en = decimal.Decimal(time.perf_counter_ns()) / 1000 + if PROFILE and display: Compiled.profile_events.append(res) # **************** Buffer + Allocators **************** diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 4ad95d3dc5..f310324634 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -767,6 +767,13 @@ class TrackedGraphRewrite: name: str|None # optional name of the rewrite depth: int # depth if it's a subrewrite bottom_up: bool + +@dataclass(frozen=True) +class TracingKey: + display_name:str # display name of this trace event + keys:tuple[str, ...]=() # optional keys to search for related traces + cat:str|None=None # optional category to color this by + tracked_keys:list[Any] = [] tracked_ctxs:list[list[TrackedGraphRewrite]] = [] _name_cnt:dict[str, itertools.count] = {} @@ -784,10 +791,15 @@ def track_rewrites(name:Callable|bool=True): if TRACK_MATCH_STATS >= 2: tracked_keys.append((fn:=func.__name__)+f" n{next(_name_cnt.setdefault(fn, itertools.count(1)))}") tracked_ctxs.append([]) - ret = func(*args, **kwargs) + # late import! + from tinygrad.device import cpu_profile + with cpu_profile(func.__name__, "TINY") as e: + ret = func(*args, **kwargs) if TRACK_MATCH_STATS >= 2 and callable(name): name_ret = name(*args, **kwargs, ret=ret) - tracked_keys[-1] = tracked_keys[-1].replace(fn, name_ret) if isinstance(name_ret, str) else name_ret + tracked_keys[-1] = key = tracked_keys[-1].replace(fn, name_ret) if isinstance(name_ret, str) else name_ret + if isinstance(key, str): e.name = TracingKey(key, (key,), func.__name__) + else: e.name = TracingKey(f"{func.__name__} for {name_ret.name}", (name_ret.name,), func.__name__) if getenv("CAPTURE_PROCESS_REPLAY"): # find the unittest frame we're capturing in frm = sys._getframe(1) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 3b714974f0..b696fafe09 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -108,7 +108,8 @@ function formatTime(ts, dur=ts) { } const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit; -const colors = ["#1D1F2A", "#2A2D3D", "#373B4F", "#444862", "#12131A", "#2F3244", "#3B3F54", "#4A4E65", "#181A23", "#232532", "#313548", "#404459"]; +const devColors = {"TINY":["#1B5745", "#1D2E62"], + "DEFAULT":["#1D1F2A", "#2A2D3D", "#373B4F", "#444862", "#12131A", "#2F3244", "#3B3F54", "#4A4E65", "#181A23", "#232532", "#313548", "#404459"],} const bufColors = ["#3A57B7","#5066C1","#6277CD","#7488D8","#8A9BE3","#A3B4F2"]; var profileRet, focusedDevice, canvasZoom, zoomLevel = d3.zoomIdentity; @@ -141,12 +142,14 @@ async function renderProfiler() { const levelHeight = baseHeight-padding; const offsetY = baseY-canvasTop+padding/2; for (const [i,e] of timeline.shapes.entries()) { - if (!nameMap.has(e.name)) { - const label = parseColors(e.name).map(({ color, st }) => ({ color, st, width:ctx.measureText(st).width })); - nameMap.set(e.name, { fillColor:colors[i%colors.length], label }); - } - // offset y by depth - data.shapes.push({ x:e.st-st, dur:e.dur, name:e.name, height:levelHeight, y:offsetY+levelHeight*e.depth, ref:e.ref, ...nameMap.get(e.name) }); + const label = parseColors(e.name).map(({ color, st }) => ({ color, st, width:ctx.measureText(st).width })); + const colorKey = e.cat ?? e.name; + if (!nameMap.has(colorKey)) { + const colors = devColors[k] ?? devColors.DEFAULT; + nameMap.set(colorKey, { fillColor:colors[i%colors.length] }); + } + // offset y by depth + data.shapes.push({ x:e.st-st, dur:e.dur, name:e.name, height:levelHeight, y:offsetY+levelHeight*e.depth, ref:e.ref, label, ...nameMap.get(colorKey) }); } // position shapes on the canvas and scale to fit fixed area const startY = offsetY+(levelHeight*timeline.maxDepth)+padding/2; diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 16c471d2f9..49ad8e25f4 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -4,7 +4,7 @@ from http.server import BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from typing import Any, TypedDict, Generator from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA -from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint +from tinygrad.uop.ops import TrackedGraphRewrite, TracingKey, UOp, Ops, printable, GroupOp, srender, sint from tinygrad.renderer import ProgramSpec from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry, ProfilePointEvent from tinygrad.dtype import dtypes @@ -114,9 +114,12 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict: depth = next((i for i,level_et in enumerate(levels) if st>=level_et), len(levels)) if depth < len(levels): levels[depth] = et else: levels.append(et) - name = e.name + name, cat = e.name, None if (ref:=ref_map.get(name)) is not None: name = ctxs[ref]["name"] - shapes.append({"name":name, "ref":ref, "st":st, "dur":dur, "depth":depth}) + elif isinstance(e.name, TracingKey): + name, cat = e.name.display_name, e.name.cat + ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None) + shapes.append({"name":name, "ref":ref, "st":st, "dur":dur, "depth":depth, "cat":cat}) return {"shapes":shapes, "maxDepth":len(levels)} def mem_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict: