From e2292f66638b1339e427db081157540bcb618307 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 9 May 2025 17:42:00 +0300 Subject: [PATCH] TRACEMETA>=2 displays UOp metadata in VIZ (#10237) --- tinygrad/ops.py | 4 +++- tinygrad/viz/serve.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b14bb25b1e..32a66a9f36 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -225,10 +225,11 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s class UOpMetaClass(type): ucache:dict[tuple, weakref.ReferenceType[UOp]] = {} - def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None): + def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, metadata:Metadata|None=None, _buffer:Buffer|None=None): if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key)) for s in src: s.children.add(ref) + if metadata is not None: all_metadata[created] = metadata # NOTE: this value is set by pickle when pickling a realized tensor if _buffer is not None: assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}" @@ -254,6 +255,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): del UOpMetaClass.ucache[k] def __reduce__(self): args = [self.op, self.dtype, self.src, self.arg] + args.append(self.metadata) if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized) return UOp, tuple(args) def replace(self, **kwargs) -> UOp: diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 2c1fccb37e..58521f22bf 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -3,7 +3,7 @@ import multiprocessing, pickle, functools, difflib, os, threading, json, time, s from http.server import BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from typing import Any, Callable, TypedDict, Generator -from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap +from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp, srender, sint from tinygrad.codegen.kernel import Kernel from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent @@ -83,6 +83,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]: label += f"\n{shape_to_str(u.shape)}" except Exception: label += "\n" + # NOTE: kernel already has metadata in arg + if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata) graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")} return graph