VIZ display cleanups (#14811)

* exclude reshape/expand broadcasts from viz

* limit src lines
This commit is contained in:
George Hotz
2026-02-17 10:03:08 +08:00
committed by GitHub
parent 5bca5be2d2
commit bc3487d607
3 changed files with 47 additions and 2 deletions

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python3
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import getenv
from extra.models.llama import TransformerBlock, precompute_freqs_cis
BS = getenv("BS", 1)
SEQLEN = getenv("SEQLEN", 128)
# SEQLEN=8192 ASM_GEMM=1 HK_FLASH_ATTENTION=1 EMULATE=AMD_CDNA4 NULL=1 DEBUG=2 VIZ=1 PYTHONPATH="." python test/external/external_test_llama3_layer.py
if __name__ == "__main__":
dim, hidden_dim, n_heads, n_kv_heads, norm_eps = 4096, 14336, 32, 8, 1e-5
layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0)
for x in nn.state.get_parameters(layer): x.replace(x.half()).realize()
freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize()
@TinyJit
def run(t): return layer(t, 0, freqs_cis, None)
for i in range(5):
print(f"*** run {i}")
run(Tensor.rand(BS, SEQLEN, dim).half().realize())

View File

@@ -188,6 +188,20 @@ class TestViz(BaseTestViz):
self.assertEqual(list(graphs[0]), [id(a), id(alu)])
self.assertEqual(list(graphs[1]), [id(z)])
def test_const_reshape_expand_folded(self):
# CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
alu = a + c
graph = uop_to_json(alu)
# the RESHAPE and EXPAND nodes from the const should not appear in the graph
labels = {v["label"].split("\n")[0] for v in graph.values()}
self.assertNotIn("RESHAPE", labels)
self.assertNotIn("EXPAND", labels)
# the CONST should be inlined into the ALU node's label
alu_label = graph[id(alu)]["label"]
self.assertIn("CONST", alu_label)
# VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None

View File

@@ -103,6 +103,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
for u in toposort:
if u in excluded: continue
argst = codecs.decode(str(u.arg), "unicode_escape")
@@ -113,8 +115,11 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
if x in excluded:
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
# walk through excluded movement ops to find the underlying CONST
cx = x
while cx.op in GroupOp.Movement and len(cx.src) >= 1 and cx.src[0] in excluded: cx = cx.src[0]
arg = f"{cx.arg:g}" if cx.op is Ops.CONST and dtypes.is_float(cx.dtype) else f"{cx.arg}"
label += f"\n{cx.op.name}{idx} {arg}" + (f" {cx.src[0].op}" if len(cx.src) else "")
try:
if len(rngs:=u.ranges):
label += f"\n({multirange_str(rngs, color=True)})"
@@ -132,6 +137,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
# NOTE: kernel already has metadata in arg
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
# limit SOURCE labels line count
if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40:
label = "\n".join(lines[:30]) + "\n..."
graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src) if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
return graph