mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add CallInfo and viz call toggle (#14570)
This commit is contained in:
@@ -14,7 +14,7 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
||||
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
|
||||
|
||||
def call_gradient(ctx:UOp, k:UOp):
|
||||
if k.arg is not None: return (None,) + k.arg(ctx, k)
|
||||
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
|
||||
# auto-differentiate the function
|
||||
fxn, args = k.src[0], k.src[1:]
|
||||
params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
|
||||
@@ -240,7 +240,7 @@ class Tensor(OpMixin):
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param, device=self.device)
|
||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device)
|
||||
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
|
||||
|
||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||
"""
|
||||
|
||||
@@ -818,7 +818,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
|
||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
||||
@@ -843,6 +844,14 @@ class CustomKernel:
|
||||
def __reduce__(self): return (CustomKernel, (panic,))
|
||||
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallInfo:
|
||||
grad_fxn: Callable|None = None
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
# grad_fxn can't be pickled, but metadata can
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata))
|
||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
|
||||
@@ -738,12 +738,13 @@ window.addEventListener("popstate", (e) => {
|
||||
});
|
||||
|
||||
const createToggle = (id, text) => {
|
||||
const label = d3.create("label").text(text).node();
|
||||
const label = d3.create("label").style("display", "block").text(text).node();
|
||||
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
|
||||
label.prepend(toggle);
|
||||
return { toggle, label };
|
||||
}
|
||||
const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)");
|
||||
const showIndexing = createToggle("show-indexing", "Show indexing (r)");
|
||||
const showCallSrc = createToggle("show-call-src", "Show CALL src (c)");
|
||||
const showGraph = createToggle("show-graph", "Show graph (g)");
|
||||
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
|
||||
|
||||
@@ -893,11 +894,13 @@ async function main() {
|
||||
// ** center graph
|
||||
const data = ret[currentRewrite];
|
||||
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
|
||||
render({ showIndexing:toggle.checked });
|
||||
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
|
||||
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked });
|
||||
render(getOpts());
|
||||
showIndexing.toggle.onchange = () => render(getOpts());
|
||||
showCallSrc.toggle.onchange = () => render(getOpts());
|
||||
// ** right sidebar metadata
|
||||
metadata.innerHTML = "";
|
||||
if (ckey.includes("rewrites")) metadata.appendChild(toggleLabel);
|
||||
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label);
|
||||
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
|
||||
if (step.trace) {
|
||||
const trace = d3.create("pre").append("code").classed("hljs", true);
|
||||
@@ -1025,7 +1028,9 @@ document.addEventListener("keydown", (event) => {
|
||||
document.getElementById("zoom-to-fit-btn").click();
|
||||
}
|
||||
// r key toggles indexing
|
||||
if (event.key === "r") toggle.click();
|
||||
if (event.key === "r") showIndexing.toggle.click();
|
||||
// c key toggles CALL src
|
||||
if (event.key === "c") showCallSrc.toggle.click();
|
||||
// g key toggles graph
|
||||
if (event.key === "g") showGraph.toggle.click();
|
||||
});
|
||||
|
||||
@@ -55,13 +55,42 @@ const layoutUOp = (g, { graph, change }, opts) => {
|
||||
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
|
||||
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
|
||||
}
|
||||
// optionally hide nodes from the layuot
|
||||
// optionally hide nodes from the layout
|
||||
if (!opts.showIndexing) {
|
||||
for (const n of g.nodes()) {
|
||||
const node = g.node(n);
|
||||
if (node.label.includes("dtypes.index")) g.removeNode(n);
|
||||
}
|
||||
}
|
||||
if (!opts.showCallSrc) {
|
||||
// remove edges from src[0] to CALL nodes, track affected nodes
|
||||
const disconnected = new Set();
|
||||
for (const n of g.nodes()) {
|
||||
const node = g.node(n);
|
||||
if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") {
|
||||
for (const pred of (g.predecessors(n) || [])) {
|
||||
const edge = g.edge(pred, n);
|
||||
if (edge?.label?.text === 0) {
|
||||
g.removeEdge(pred, n);
|
||||
disconnected.add(pred);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// remove nodes that are now disconnected (no successors), only from affected subtree
|
||||
let changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (const n of disconnected) {
|
||||
if (!g.hasNode(n)) continue;
|
||||
if ((g.successors(n) || []).length === 0) {
|
||||
for (const pred of (g.predecessors(n) || [])) disconnected.add(pred);
|
||||
g.removeNode(n);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dagre.layout(g);
|
||||
// remove overlay node if it's empty
|
||||
if (!g.node("overlay")?.width) g.removeNode("overlay");
|
||||
|
||||
Reference in New Issue
Block a user