TRACEMETA>=2 displays UOp metadata in VIZ (#10237)

This commit is contained in:
qazal
2025-05-09 17:42:00 +03:00
committed by GitHub
parent d5686f33a9
commit e2292f6663
2 changed files with 6 additions and 2 deletions

View File

@@ -225,10 +225,11 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
class UOpMetaClass(type):
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, metadata:Metadata|None=None, _buffer:Buffer|None=None):
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
for s in src: s.children.add(ref)
if metadata is not None: all_metadata[created] = metadata
# NOTE: this value is set by pickle when pickling a realized tensor
if _buffer is not None:
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
@@ -254,6 +255,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
del UOpMetaClass.ucache[k]
def __reduce__(self):
args = [self.op, self.dtype, self.src, self.arg]
args.append(self.metadata)
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
return UOp, tuple(args)
def replace(self, **kwargs) -> UOp:

View File

@@ -3,7 +3,7 @@ import multiprocessing, pickle, functools, difflib, os, threading, json, time, s
from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, Callable, TypedDict, Generator
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp, srender, sint
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
@@ -83,6 +83,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
label += f"\n{shape_to_str(u.shape)}"
except Exception:
label += "\n<ISSUE GETTING SHAPE>"
# NOTE: kernel already has metadata in arg
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")}
return graph