mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 05:18:01 -05:00
TRACEMETA>=2 displays UOp metadata in VIZ (#10237)
This commit is contained in:
@@ -225,10 +225,11 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
||||
|
||||
class UOpMetaClass(type):
|
||||
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, metadata:Metadata|None=None, _buffer:Buffer|None=None):
|
||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
|
||||
for s in src: s.children.add(ref)
|
||||
if metadata is not None: all_metadata[created] = metadata
|
||||
# NOTE: this value is set by pickle when pickling a realized tensor
|
||||
if _buffer is not None:
|
||||
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
|
||||
@@ -254,6 +255,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
del UOpMetaClass.ucache[k]
|
||||
def __reduce__(self):
|
||||
args = [self.op, self.dtype, self.src, self.arg]
|
||||
args.append(self.metadata)
|
||||
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
|
||||
return UOp, tuple(args)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
|
||||
@@ -3,7 +3,7 @@ import multiprocessing, pickle, functools, difflib, os, threading, json, time, s
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Any, Callable, TypedDict, Generator
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA
|
||||
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp, srender, sint
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
|
||||
@@ -83,6 +83,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
label += f"\n{shape_to_str(u.shape)}"
|
||||
except Exception:
|
||||
label += "\n<ISSUE GETTING SHAPE>"
|
||||
# NOTE: kernel already has metadata in arg
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
|
||||
graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")}
|
||||
return graph
|
||||
|
||||
|
||||
Reference in New Issue
Block a user