VIZ display cleanups (#14811)

* exclude reshape/expand broadcasts from viz

* limit src lines
This commit is contained in:
George Hotz
2026-02-17 10:03:08 +08:00
committed by GitHub
parent 5bca5be2d2
commit bc3487d607
3 changed files with 47 additions and 2 deletions

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python3
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import getenv
from extra.models.llama import TransformerBlock, precompute_freqs_cis
BS = getenv("BS", 1)
SEQLEN = getenv("SEQLEN", 128)
# SEQLEN=8192 ASM_GEMM=1 HK_FLASH_ATTENTION=1 EMULATE=AMD_CDNA4 NULL=1 DEBUG=2 VIZ=1 PYTHONPATH="." python test/external/external_test_llama3_layer.py
if __name__ == "__main__":
dim, hidden_dim, n_heads, n_kv_heads, norm_eps = 4096, 14336, 32, 8, 1e-5
layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0)
for x in nn.state.get_parameters(layer): x.replace(x.half()).realize()
freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize()
@TinyJit
def run(t): return layer(t, 0, freqs_cis, None)
for i in range(5):
print(f"*** run {i}")
run(Tensor.rand(BS, SEQLEN, dim).half().realize())

View File

@@ -188,6 +188,20 @@ class TestViz(BaseTestViz):
self.assertEqual(list(graphs[0]), [id(a), id(alu)])
self.assertEqual(list(graphs[1]), [id(z)])
def test_const_reshape_expand_folded(self):
# CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
alu = a + c
graph = uop_to_json(alu)
# the RESHAPE and EXPAND nodes from the const should not appear in the graph
labels = {v["label"].split("\n")[0] for v in graph.values()}
self.assertNotIn("RESHAPE", labels)
self.assertNotIn("EXPAND", labels)
# the CONST should be inlined into the ALU node's label
alu_label = graph[id(alu)]["label"]
self.assertIn("CONST", alu_label)
# VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None