add CallInfo and viz call toggle (#14570)

This commit is contained in:
George Hotz
2026-02-06 09:30:58 +08:00
committed by GitHub
parent f73468d516
commit 28c56a783c
5 changed files with 53 additions and 10 deletions

View File

@@ -14,7 +14,7 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
def call_gradient(ctx:UOp, k:UOp):
if k.arg is not None: return (None,) + k.arg(ctx, k)
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
# auto-differentiate the function
fxn, args = k.src[0], k.src[1:]
params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)

View File

@@ -240,7 +240,7 @@ class Tensor(OpMixin):
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param, device=self.device)
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device)
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
"""

View File

@@ -818,7 +818,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
@@ -843,6 +844,14 @@ class CustomKernel:
def __reduce__(self): return (CustomKernel, (panic,))
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
@dataclass(frozen=True)
class CallInfo:
grad_fxn: Callable|None = None
metadata: tuple[Metadata, ...] = ()
# grad_fxn can't be pickled, but metadata can
def __reduce__(self): return (CallInfo, (None, self.metadata))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
@dataclass(frozen=True)
class Kernel:
ast: UOp

View File

@@ -738,12 +738,13 @@ window.addEventListener("popstate", (e) => {
});
const createToggle = (id, text) => {
const label = d3.create("label").text(text).node();
const label = d3.create("label").style("display", "block").text(text).node();
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
label.prepend(toggle);
return { toggle, label };
}
const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)");
const showIndexing = createToggle("show-indexing", "Show indexing (r)");
const showCallSrc = createToggle("show-call-src", "Show CALL src (c)");
const showGraph = createToggle("show-graph", "Show graph (g)");
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
@@ -893,11 +894,13 @@ async function main() {
// ** center graph
const data = ret[currentRewrite];
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
render({ showIndexing:toggle.checked });
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked });
render(getOpts());
showIndexing.toggle.onchange = () => render(getOpts());
showCallSrc.toggle.onchange = () => render(getOpts());
// ** right sidebar metadata
metadata.innerHTML = "";
if (ckey.includes("rewrites")) metadata.appendChild(toggleLabel);
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label);
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
if (step.trace) {
const trace = d3.create("pre").append("code").classed("hljs", true);
@@ -1025,7 +1028,9 @@ document.addEventListener("keydown", (event) => {
document.getElementById("zoom-to-fit-btn").click();
}
// r key toggles indexing
if (event.key === "r") toggle.click();
if (event.key === "r") showIndexing.toggle.click();
// c key toggles CALL src
if (event.key === "c") showCallSrc.toggle.click();
// g key toggles graph
if (event.key === "g") showGraph.toggle.click();
});

View File

@@ -55,13 +55,42 @@ const layoutUOp = (g, { graph, change }, opts) => {
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
}
// optionally hide nodes from the layuot
// optionally hide nodes from the layout
if (!opts.showIndexing) {
for (const n of g.nodes()) {
const node = g.node(n);
if (node.label.includes("dtypes.index")) g.removeNode(n);
}
}
if (!opts.showCallSrc) {
// remove edges from src[0] to CALL nodes, track affected nodes
const disconnected = new Set();
for (const n of g.nodes()) {
const node = g.node(n);
if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") {
for (const pred of (g.predecessors(n) || [])) {
const edge = g.edge(pred, n);
if (edge?.label?.text === 0) {
g.removeEdge(pred, n);
disconnected.add(pred);
}
}
}
}
// remove nodes that are now disconnected (no successors), only from affected subtree
let changed = true;
while (changed) {
changed = false;
for (const n of disconnected) {
if (!g.hasNode(n)) continue;
if ((g.successors(n) || []).length === 0) {
for (const pred of (g.predecessors(n) || [])) disconnected.add(pred);
g.removeNode(n);
changed = true;
}
}
}
}
dagre.layout(g);
// remove overlay node if it's empty
if (!g.node("overlay")?.width) g.removeNode("overlay");