mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
VIZ display cleanups (#14811)

* exclude reshape/expand broadcasts from viz
* limit src lines
This commit is contained in:
23 lines — test/external/external_test_llama3_layer.py (vendored, new file)
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python3
|
||||
from tinygrad import Tensor, TinyJit, nn
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.models.llama import TransformerBlock, precompute_freqs_cis
|
||||
|
||||
BS = getenv("BS", 1)
|
||||
SEQLEN = getenv("SEQLEN", 128)
|
||||
|
||||
# SEQLEN=8192 ASM_GEMM=1 HK_FLASH_ATTENTION=1 EMULATE=AMD_CDNA4 NULL=1 DEBUG=2 VIZ=1 PYTHONPATH="." python test/external/external_test_llama3_layer.py
|
||||
|
||||
if __name__ == "__main__":
|
||||
dim, hidden_dim, n_heads, n_kv_heads, norm_eps = 4096, 14336, 32, 8, 1e-5
|
||||
layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0)
|
||||
for x in nn.state.get_parameters(layer): x.replace(x.half()).realize()
|
||||
|
||||
freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize()
|
||||
|
||||
@TinyJit
|
||||
def run(t): return layer(t, 0, freqs_cis, None)
|
||||
|
||||
for i in range(5):
|
||||
print(f"*** run {i}")
|
||||
run(Tensor.rand(BS, SEQLEN, dim).half().realize())
|
||||
@@ -188,6 +188,20 @@ class TestViz(BaseTestViz):
|
||||
self.assertEqual(list(graphs[0]), [id(a), id(alu)])
|
||||
self.assertEqual(list(graphs[1]), [id(z)])
|
||||
|
||||
def test_const_reshape_expand_folded(self):
|
||||
# CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
|
||||
c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
|
||||
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
|
||||
alu = a + c
|
||||
graph = uop_to_json(alu)
|
||||
# the RESHAPE and EXPAND nodes from the const should not appear in the graph
|
||||
labels = {v["label"].split("\n")[0] for v in graph.values()}
|
||||
self.assertNotIn("RESHAPE", labels)
|
||||
self.assertNotIn("EXPAND", labels)
|
||||
# the CONST should be inlined into the ALU node's label
|
||||
alu_label = graph[id(alu)]["label"]
|
||||
self.assertIn("CONST", alu_label)
|
||||
|
||||
# VIZ displays nested graph_rewrites in a tree view
|
||||
|
||||
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None
|
||||
|
||||
@@ -103,6 +103,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
for u in toposort:
|
||||
if u in excluded: continue
|
||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||
@@ -113,8 +115,11 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
|
||||
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
|
||||
if x in excluded:
|
||||
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
|
||||
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
|
||||
# walk through excluded movement ops to find the underlying CONST
|
||||
cx = x
|
||||
while cx.op in GroupOp.Movement and len(cx.src) >= 1 and cx.src[0] in excluded: cx = cx.src[0]
|
||||
arg = f"{cx.arg:g}" if cx.op is Ops.CONST and dtypes.is_float(cx.dtype) else f"{cx.arg}"
|
||||
label += f"\n{cx.op.name}{idx} {arg}" + (f" {cx.src[0].op}" if len(cx.src) else "")
|
||||
try:
|
||||
if len(rngs:=u.ranges):
|
||||
label += f"\n({multirange_str(rngs, color=True)})"
|
||||
@@ -132,6 +137,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||
# NOTE: kernel already has metadata in arg
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
|
||||
# limit SOURCE labels line count
|
||||
if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40:
|
||||
label = "\n".join(lines[:30]) + "\n..."
|
||||
graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src) if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
|
||||
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
|
||||
return graph
|
||||
|
||||
Reference in New Issue
Block a user