diff --git a/test/external/external_test_llama3_layer.py b/test/external/external_test_llama3_layer.py
new file mode 100644
index 0000000000..0ec028620e
--- /dev/null
+++ b/test/external/external_test_llama3_layer.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+from tinygrad import Tensor, TinyJit, nn
+from tinygrad.helpers import getenv
+from extra.models.llama import TransformerBlock, precompute_freqs_cis
+
+BS = getenv("BS", 1)
+SEQLEN = getenv("SEQLEN", 128)
+
+# SEQLEN=8192 ASM_GEMM=1 HK_FLASH_ATTENTION=1 EMULATE=AMD_CDNA4 NULL=1 DEBUG=2 VIZ=1 PYTHONPATH="." python test/external/external_test_llama3_layer.py
+
+if __name__ == "__main__":
+  dim, hidden_dim, n_heads, n_kv_heads, norm_eps = 4096, 14336, 32, 8, 1e-5
+  layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0)
+  for x in nn.state.get_parameters(layer): x.replace(x.half()).realize()
+
+  freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize()
+
+  @TinyJit
+  def run(t): return layer(t, 0, freqs_cis, None)
+
+  for i in range(5):
+    print(f"*** run {i}")
+    run(Tensor.rand(BS, SEQLEN, dim).half().realize())
diff --git a/test/null/test_viz.py b/test/null/test_viz.py
index a89cf56b32..731cccd67a 100644
--- a/test/null/test_viz.py
+++ b/test/null/test_viz.py
@@ -188,6 +188,20 @@ class TestViz(BaseTestViz):
     self.assertEqual(list(graphs[0]), [id(a), id(alu)])
     self.assertEqual(list(graphs[1]), [id(z)])
 
+  def test_const_reshape_expand_folded(self):
+    # CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
+    c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
+    a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
+    alu = a + c
+    graph = uop_to_json(alu)
+    # the RESHAPE and EXPAND nodes from the const should not appear in the graph
+    labels = {v["label"].split("\n")[0] for v in graph.values()}
+    self.assertNotIn("RESHAPE", labels)
+    self.assertNotIn("EXPAND", labels)
+    # the CONST should be inlined into the ALU node's label
+    alu_label = graph[id(alu)]["label"]
+    self.assertIn("CONST", alu_label)
+
   # VIZ displays nested graph_rewrites in a tree view
   def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 820e5a2eaf..05988552bc 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -103,6 +103,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
     if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
     if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
+    # exclude RESHAPE/EXPAND that only serve to broadcast a CONST
+    if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
   for u in toposort:
     if u in excluded: continue
     argst = codecs.decode(str(u.arg), "unicode_escape")
@@ -113,8 +115,11 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     if u.dtype != dtypes.void: label += f"\n{u.dtype}"
     for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
       if x in excluded:
-        arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
-        label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
+        # walk through excluded movement ops to find the underlying CONST
+        cx = x
+        while cx.op in GroupOp.Movement and len(cx.src) >= 1 and cx.src[0] in excluded: cx = cx.src[0]
+        arg = f"{cx.arg:g}" if cx.op is Ops.CONST and dtypes.is_float(cx.dtype) else f"{cx.arg}"
+        label += f"\n{cx.op.name}{idx} {arg}" + (f" {cx.src[0].op}" if len(cx.src) else "")
     try:
       if len(rngs:=u.ranges): label += f"\n({multirange_str(rngs, color=True)})"
@@ -132,6 +137,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
     # NOTE: kernel already has metadata in arg
     if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
+    # limit SOURCE labels line count
+    if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40:
+      label = "\n".join(lines[:30]) + "\n..."
     graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src) if x not in excluded],
                     "color":uops_colors.get(u.op, "#ffffff"), "ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
   return graph