viz: renames and spacing changes to tracing (#11102)

This commit is contained in:
qazal
2025-07-05 18:40:39 +03:00
committed by GitHub
parent 7619bf35e7
commit 81781dc12b
4 changed files with 15 additions and 14 deletions

View File

@@ -105,13 +105,13 @@ class TestViz(unittest.TestCase):
# name can also come from a function that returns a TracingKey
def test_tracing_key(self):
@track_rewrites(name=lambda inp,ret: TracingKey("custom_name", fmt=f"input={inp.render()}"))
@track_rewrites(name=lambda inp,ret: TracingKey("custom_name", (inp,), fmt=f"input={inp.render()}"))
def test(s:UOp): return graph_rewrite(s, PatternMatcher([]))
test(UOp.variable("a", 1, 10)+1)
lst = get_viz_list()
# NOTE: names from TracingKey do not get deduped
self.assertEqual(lst[0]["name"], "custom_name")
self.assertEqual(lst[0]["kernel_code"], "input=(a+1)")
self.assertEqual(lst[0]["fmt"], "input=(a+1)")
def test_colored_label(self):
# NOTE: dataclass repr prints literal escape codes instead of unicode chars

View File

@@ -759,21 +759,22 @@ def track_uop(u:UOp):
VIZ = ContextVar("VIZ", 0)
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if VIZ else 0)
match_stats:dict[UPat, list[Union[int, float]]] = dict()
@dataclass(frozen=True)
class TrackedGraphRewrite:
loc: tuple[str, int] # location that called graph_rewrite
sink: int # the sink input to graph_rewrite
matches: list[tuple[int, int, tuple]] # before/after UOp, UPat location
name: str|None # optional name of the rewrite
depth: int # depth if it's a subrewrite
bottom_up: bool
loc:tuple[str, int] # location that called graph_rewrite
sink:int # the sink input to graph_rewrite
matches:list[tuple[int, int, tuple]] # before/after UOp, UPat location
name:str|None # optional name of the rewrite
depth:int # depth if it's a subrewrite
bottom_up:bool
@dataclass(frozen=True)
class TracingKey:
display_name:str # display name of this trace event
keys:tuple[str, ...]=() # optional keys to search for related traces
fmt:str|None=None # optional detailed formatting
cat:str|None=None # optional category to color this by
display_name:str # display name of this trace event
keys:tuple[str, ...]=() # optional keys to search for related traces
fmt:str|None=None # optional detailed formatting
cat:str|None=None # optional category to color this by
tracked_keys:list[Any] = []
tracked_ctxs:list[list[TrackedGraphRewrite]] = []

View File

@@ -478,7 +478,7 @@ async function main() {
renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0);
// ** right sidebar code blocks
const metadata = document.querySelector(".metadata");
const [code, lang] = ctx.kernel_code != null ? [ctx.kernel_code, "cpp"] : [ret[currentRewrite].uop, "python"];
const [code, lang] = ctx.fmt != null ? [ctx.fmt, "cpp"] : [ret[currentRewrite].uop, "python"];
metadata.replaceChildren(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }), codeBlock(code, lang, { wrap:false }));
// ** rewrite steps
if (step.match_count >= 1) {

View File

@@ -25,8 +25,8 @@ def get_metadata(keys:list[TracingKey], contexts:list[list[TrackedGraphRewrite]]
ret = []
for i,(k,v) in enumerate(zip(keys, contexts)):
steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc)} for s in v]
ret.append({"name":k.display_name, "fmt":k.fmt, "steps":steps})
for key in k.keys: ref_map[key] = i
ret.append({"name":k.display_name, "kernel_code":k.fmt, "steps":steps})
return ret
# ** Complete rewrite details for a graph_rewrite call