mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: renames and spacing changes to tracing (#11102)
This commit is contained in:
@@ -105,13 +105,13 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
# name can also come from a function that returns a TracingKey
|
||||
def test_tracing_key(self):
|
||||
@track_rewrites(name=lambda inp,ret: TracingKey("custom_name", fmt=f"input={inp.render()}"))
|
||||
@track_rewrites(name=lambda inp,ret: TracingKey("custom_name", (inp,), fmt=f"input={inp.render()}"))
|
||||
def test(s:UOp): return graph_rewrite(s, PatternMatcher([]))
|
||||
test(UOp.variable("a", 1, 10)+1)
|
||||
lst = get_viz_list()
|
||||
# NOTE: names from TracingKey do not get deduped
|
||||
self.assertEqual(lst[0]["name"], "custom_name")
|
||||
self.assertEqual(lst[0]["kernel_code"], "input=(a+1)")
|
||||
self.assertEqual(lst[0]["fmt"], "input=(a+1)")
|
||||
|
||||
def test_colored_label(self):
|
||||
# NOTE: dataclass repr prints literal escape codes instead of unicode chars
|
||||
|
||||
@@ -759,21 +759,22 @@ def track_uop(u:UOp):
|
||||
VIZ = ContextVar("VIZ", 0)
|
||||
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if VIZ else 0)
|
||||
match_stats:dict[UPat, list[Union[int, float]]] = dict()
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrackedGraphRewrite:
|
||||
loc: tuple[str, int] # location that called graph_rewrite
|
||||
sink: int # the sink input to graph_rewrite
|
||||
matches: list[tuple[int, int, tuple]] # before/after UOp, UPat location
|
||||
name: str|None # optional name of the rewrite
|
||||
depth: int # depth if it's a subrewrite
|
||||
bottom_up: bool
|
||||
loc:tuple[str, int] # location that called graph_rewrite
|
||||
sink:int # the sink input to graph_rewrite
|
||||
matches:list[tuple[int, int, tuple]] # before/after UOp, UPat location
|
||||
name:str|None # optional name of the rewrite
|
||||
depth:int # depth if it's a subrewrite
|
||||
bottom_up:bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TracingKey:
|
||||
display_name:str # display name of this trace event
|
||||
keys:tuple[str, ...]=() # optional keys to search for related traces
|
||||
fmt:str|None=None # optional detailed formatting
|
||||
cat:str|None=None # optional category to color this by
|
||||
display_name:str # display name of this trace event
|
||||
keys:tuple[str, ...]=() # optional keys to search for related traces
|
||||
fmt:str|None=None # optional detailed formatting
|
||||
cat:str|None=None # optional category to color this by
|
||||
|
||||
tracked_keys:list[Any] = []
|
||||
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
||||
|
||||
@@ -478,7 +478,7 @@ async function main() {
|
||||
renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0);
|
||||
// ** right sidebar code blocks
|
||||
const metadata = document.querySelector(".metadata");
|
||||
const [code, lang] = ctx.kernel_code != null ? [ctx.kernel_code, "cpp"] : [ret[currentRewrite].uop, "python"];
|
||||
const [code, lang] = ctx.fmt != null ? [ctx.fmt, "cpp"] : [ret[currentRewrite].uop, "python"];
|
||||
metadata.replaceChildren(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }), codeBlock(code, lang, { wrap:false }));
|
||||
// ** rewrite steps
|
||||
if (step.match_count >= 1) {
|
||||
|
||||
@@ -25,8 +25,8 @@ def get_metadata(keys:list[TracingKey], contexts:list[list[TrackedGraphRewrite]]
|
||||
ret = []
|
||||
for i,(k,v) in enumerate(zip(keys, contexts)):
|
||||
steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc)} for s in v]
|
||||
ret.append({"name":k.display_name, "fmt":k.fmt, "steps":steps})
|
||||
for key in k.keys: ref_map[key] = i
|
||||
ret.append({"name":k.display_name, "kernel_code":k.fmt, "steps":steps})
|
||||
return ret
|
||||
|
||||
# ** Complete rewrite details for a graph_rewrite call
|
||||
|
||||
Reference in New Issue
Block a user