From 28c56a783c4678b44e080e6d49d75ab75c799486 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 6 Feb 2026 09:30:58 +0800 Subject: [PATCH] add CallInfo and viz call toggle (#14570) --- tinygrad/gradient.py | 2 +- tinygrad/tensor.py | 2 +- tinygrad/uop/ops.py | 11 ++++++++++- tinygrad/viz/js/index.js | 17 +++++++++++------ tinygrad/viz/js/worker.js | 31 ++++++++++++++++++++++++++++++- 5 files changed, 53 insertions(+), 10 deletions(-) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index e6e48908f5..b5b9a4f904 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -14,7 +14,7 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops): if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],) def call_gradient(ctx:UOp, k:UOp): - if k.arg is not None: return (None,) + k.arg(ctx, k) + if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k) # auto-differentiate the function fxn, args = k.src[0], k.src[1:] params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cb9e4c0dbc..eec5eeac41 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -240,7 +240,7 @@ class Tensor(OpMixin): param = UOp.param(slot, self.dtype, self.shape, self.device) return Tensor(param, device=self.device) def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor: - return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device) + return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device) def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]: """ diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 1a1ddbeced..044e8de1b5 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -818,7 +818,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),)) return UOp(Ops.PARAM, dtype, src, arg=slot) - def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg) + def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp: + return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata)) def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]: contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs) kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn)) @@ -843,6 +844,14 @@ class CustomKernel: def __reduce__(self): return (CustomKernel, (panic,)) def __repr__(self): return f"CustomKernel({id(self.fxn)})" +@dataclass(frozen=True) +class CallInfo: + grad_fxn: Callable|None = None + metadata: tuple[Metadata, ...] = () + # grad_fxn can't be pickled, but metadata can + def __reduce__(self): return (CallInfo, (None, self.metadata)) + def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})" + @dataclass(frozen=True) class Kernel: ast: UOp diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index d78c73cff7..b9ecb033be 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -738,12 +738,13 @@ window.addEventListener("popstate", (e) => { }); const createToggle = (id, text) => { - const label = d3.create("label").text(text).node(); + const label = d3.create("label").style("display", "block").text(text).node(); const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node(); label.prepend(toggle); return { toggle, label }; } -const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)"); +const showIndexing = createToggle("show-indexing", "Show indexing (r)"); +const showCallSrc = createToggle("show-call-src", "Show CALL src (c)"); const showGraph = createToggle("show-graph", "Show graph (g)"); showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph"); @@ -893,11 +894,13 @@ async function main() { // ** center graph const data = ret[currentRewrite]; const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 }); - render({ showIndexing:toggle.checked }); - toggle.onchange = (e) => render({ showIndexing:e.target.checked }); + const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked }); + render(getOpts()); + showIndexing.toggle.onchange = () => render(getOpts()); + showCallSrc.toggle.onchange = () => render(getOpts()); // ** right sidebar metadata metadata.innerHTML = ""; - if (ckey.includes("rewrites")) metadata.appendChild(toggleLabel); + if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label); if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true })); if (step.trace) { const trace = d3.create("pre").append("code").classed("hljs", true); @@ -1025,7 +1028,9 @@ document.addEventListener("keydown", (event) => { document.getElementById("zoom-to-fit-btn").click(); } // r key toggles indexing - if (event.key === "r") toggle.click(); + if (event.key === "r") showIndexing.toggle.click(); + // c key toggles CALL src + if (event.key === "c") showCallSrc.toggle.click(); // g key toggles graph if (event.key === "g") showGraph.toggle.click(); }); diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index 3086243a7d..b67c4f3249 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -55,13 +55,42 @@ const layoutUOp = (g, { graph, change }, opts) => { for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}}); if (change?.includes(parseInt(k))) g.setParent(k, "overlay"); } - // optionally hide nodes from the layuot + // optionally hide nodes from the layout if (!opts.showIndexing) { for (const n of g.nodes()) { const node = g.node(n); if (node.label.includes("dtypes.index")) g.removeNode(n); } } + if (!opts.showCallSrc) { + // remove edges from src[0] to CALL nodes, track affected nodes + const disconnected = new Set(); + for (const n of g.nodes()) { + const node = g.node(n); + if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") { + for (const pred of (g.predecessors(n) || [])) { + const edge = g.edge(pred, n); + if (edge?.label?.text === 0) { + g.removeEdge(pred, n); + disconnected.add(pred); + } + } + } + } + // remove nodes that are now disconnected (no successors), only from affected subtree + let changed = true; + while (changed) { + changed = false; + for (const n of disconnected) { + if (!g.hasNode(n)) continue; + if ((g.successors(n) || []).length === 0) { + for (const pred of (g.predecessors(n) || [])) disconnected.add(pred); + g.removeNode(n); + changed = true; + } + } + } + } dagre.layout(g); // remove overlay node if it's empty if (!g.node("overlay")?.width) g.removeNode("overlay");