From d32f5e9f3a3aec6abbfe1c96680b5402bef5fb17 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 28 Apr 2025 16:44:09 -0400 Subject: [PATCH] improve rendering of shapes in viz + investigate symbolic [pr] (#10091) --- test/test_symbolic_ops.py | 15 +++++++++++---- tinygrad/ops.py | 2 ++ tinygrad/runtime/support/am/amdev.py | 2 +- tinygrad/viz/serve.py | 10 ++++++---- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index b877b2178b..48d3c2c0af 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -1,6 +1,6 @@ import unittest from tinygrad import Variable -from tinygrad.helpers import Context +from tinygrad.helpers import Context, GlobalCounters from tinygrad.tensor import Tensor from examples.gpt2 import Attention import numpy as np @@ -43,17 +43,24 @@ class TestSymbolicOps(unittest.TestCase): expected = f(a, b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_attention(self, dropout_p=0.0): + def test_attention(self, dropout_p=0.0, imin=1, imax=5, use_symbolic=True): def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize() - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) + for i in range(imin, imax): + vi = Variable("i", 1, 10).bind(i) if use_symbolic else i q = Tensor.rand(2, 1, 4, 8) k = Tensor.rand(2, i, 4, 8) v = Tensor.rand(2, i, 4, 8) + Tensor.realize(q, k, v) + GlobalCounters.reset() symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy() expected = f(q, k, v).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_attention_cmp_symbolic(self): + # symbolic isn't seeing if i == i, so it's not putting them on the same axis + self.test_attention(imin=4, imax=5, use_symbolic=False) + self.test_attention(imin=4, imax=5, use_symbolic=True) + def test_attention_training(self): with Tensor.train(): self.test_attention(dropout_p=0.0) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 004093bcd8..4b01e19896 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -207,6 +207,7 @@ def _suop(lst, uop_fxn, python_fxn): return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else []))) def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max) def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min) +def srender(x) -> str: return x.render() if isinstance(x, UOp) else str(x) def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop @@ -1001,6 +1002,7 @@ renderer = PatternMatcher([ (UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")), (UPat(Ops.LOAD), lambda: UOp(Ops.NOOP, arg="load")), (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), + #(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")), (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 249386dc23..f7e1636727 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -71,7 +71,7 @@ class AMFirmware: self.descs += [self.desc(blob, imu_i_off, imu_i_sz, am.GFX_FW_TYPE_IMU_I), self.desc(blob, imu_i_off + imu_i_sz, imu_d_sz, am.GFX_FW_TYPE_IMU_D)] # RLC firmware - blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0, + blob, hdr0, _hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0, am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3) for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]: diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 6584248d58..2408401178 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -4,7 +4,7 @@ from http.server import BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from typing import Any, Callable, TypedDict, Generator from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap -from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp +from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp, srender, sint from tinygrad.codegen.kernel import Kernel from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent from tinygrad.dtype import dtypes @@ -51,6 +51,8 @@ class GraphRewriteDetails(TypedDict): changed_nodes: list[int]|None # the changed UOp id + all its parents ids upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat +def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")" + def uop_to_json(x:UOp) -> dict[int, dict]: assert isinstance(x, UOp) graph: dict[int, dict] = {} @@ -65,8 +67,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]: if u in excluded: continue argst = str(u.arg) if u.op is Ops.VIEW: - argst = ("\n".join([f"{v.shape} / {v.strides}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+ - ("" if v.offset == 0 else f" / {v.offset}") for v in unwrap(u.st).views])) + argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+ + ("" if v.offset == 0 else f" / {srender(v.offset)}") for v in unwrap(u.st).views])) label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}" if u.dtype != dtypes.void: label += f"\n{u.dtype}" for idx,x in enumerate(u.src): @@ -75,7 +77,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]: else: label += f"\n{x.op.name}{idx} {x.arg}" try: if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None: - label += f"\n{repr(u.shape)}" + label += f"\n{shape_to_str(u.shape)}" except Exception: label += "\n" graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")}