From 455a27dd4361b4af88309ee88a4b045becb3f05a Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 17 Sep 2024 18:58:23 +0800 Subject: [PATCH] start viz unittests (#6550) * test_viz * more tests --- test/external/fuzz_viz.py | 6 ++--- test/test_viz.py | 50 +++++++++++++++++++++++++++++++++++++++ viz/serve.py | 22 ++++++++--------- 3 files changed, 64 insertions(+), 14 deletions(-) create mode 100644 test/test_viz.py diff --git a/test/external/fuzz_viz.py b/test/external/fuzz_viz.py index 4dc84aefe8..e7a5c5b46e 100644 --- a/test/external/fuzz_viz.py +++ b/test/external/fuzz_viz.py @@ -19,7 +19,7 @@ if __name__ == "__main__": list(lower_schedule(sched)) uret = create_graph(contexts[0]) assert uret.loc.split(":")[0] == "schedule.py" - assert len(uret.graphs) == len(uret.extra) == 1 + assert len(uret.uops) == len(uret.extra) == 1 assert len(uret.diffs) == 0 contexts.clear() @@ -33,8 +33,8 @@ if __name__ == "__main__": for ctx in tqdm(contexts): st = time.perf_counter() ret = create_graph(ctx) - assert len(ret.graphs) == len(ret.extra) - assert len(ret.diffs) == len(ret.graphs)-1, f"found {len(ret.diffs)} diffs but only {len(ret.graphs)-1} graphs" + assert len(ret.uops) == len(ret.extra) + assert len(ret.diffs) == len(ret.uops)-1, f"found {len(ret.diffs)} diffs but only {len(ret.uops)-1} uops" tms.append(time.perf_counter()-st) timings = list(zip(contexts, tms)) diff --git a/test/test_viz.py b/test/test_viz.py new file mode 100644 index 0000000000..4e4ec1feaa --- /dev/null +++ b/test/test_viz.py @@ -0,0 +1,50 @@ +# ** setup +from typing import List +import unittest +import os +prev_val = os.getenv("TRACK_MATCH_STATS") +os.environ["TRACK_MATCH_STATS"] = "2" +os.environ["FORWARD_ONLY"] = "1" +from tinygrad.helpers import DEBUG +from tinygrad.ops import UOp, contexts +from tinygrad import Tensor +from tinygrad.engine.realize import lower_schedule +from test.external.process_replay.helpers import print_diff +from viz.serve import create_graph, replace_uop + +class TestViz(unittest.TestCase): + def test_ctx_diff(self): + a = Tensor.ones(4, 1).contiguous().realize() + out = a + a.reshape(1, 4) + out.realize() + for ctx in contexts: + uops = [ctx.sink] + for i, (first, rewritten, pat) in enumerate(ctx.rewrites): + start = uops[-1] + found = [x for x in start.sparents if x.key == first.key] + assert found, f"can't find UOp for rewrite_num={i} pattern={pat}" + changed: List[UOp] = [] + new = replace_uop(start, first, rewritten, cache={}) + if DEBUG >= 4: print_diff(start, new) + changed = [x for x in new.sparents if x not in start.sparents] + assert len(changed) == len(found), f"{len(changed)} != {len(found)}" + assert tuple(changed) == tuple(found), f"{changed} != {found}" + + @unittest.skip("TODO: this graph doesn't change") + def test_gemm_diff(self): + x = Tensor.empty(64, 64).realize() + y = Tensor.empty(64, 64).realize() + out = x.matmul(y) + contexts.clear() + s = out.schedule() + list(lower_schedule(s)) + ctx = contexts[3] + ret = create_graph(ctx) + for i, (x,y) in enumerate(zip(ret.uops, ret.uops[1:])): + if x.key == y.key: + raise AssertionError(f"failed to generate the correct diff at rewrite {i}") + +if __name__ == "__main__": + unittest.main() + if prev_val: os.environ["TRACK_MATCH_STATS"] + else: del os.environ["TRACK_MATCH_STATS"] diff --git a/viz/serve.py b/viz/serve.py index 6b8deba387..d90cf9a1b0 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Dict, List, Tuple import pickle, re, os, sys, time, threading, webbrowser, json, difflib, contextlib from tinygrad.helpers import getenv @@ -29,15 +29,15 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: @dataclass(frozen=True) class UOpRet: - loc: str # location that called graph_rewrite - graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]] # a seralized version of UOp graphs - diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite - extra: List[List[str]] # these become code blocks in the UI + loc: str # location that called graph_rewrite + uops: List[UOp] # snapshot of the entire AST after each rewrite + diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite + extra: List[List[str]] # these become code blocks in the UI -def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[UOp, UOp]) -> UOp: +def replace_uop(base:UOp, first:UOp, new:UOp, cache:Dict[UOp, UOp]) -> UOp: if (u:=cache.get(base)): return u - new_srcs = tuple(new if x.key == prev.key else replace_uop(x, prev, new, cache) for x in base.src) - ret = cache[base] = base if new_srcs == base.src else UOp(base.op, base.dtype, new_srcs, base.arg) + if base.key == first.key: return new + ret = cache[base] = UOp(base.op, base.dtype, tuple(replace_uop(x, first, new, cache) for x in base.src), base.arg) return ret def create_graph(ctx:TrackedRewriteContext) -> UOpRet: @@ -51,7 +51,7 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet: assert new_sink.op is UOps.SINK uops.append(new_sink) extra.append([str(new_sink)]) - return UOpRet(ctx.loc, list(map(uop_to_json, uops)), diffs, extra) + return UOpRet(ctx.loc, uops, diffs, extra) class Handler(BaseHTTPRequestHandler): def do_GET(self): @@ -73,8 +73,8 @@ class Handler(BaseHTTPRequestHandler): self.end_headers() with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f) rest = [x.loc for x in contexts] - current_graph = create_graph(contexts[int(self.path.split("/")[-1])]) - ret = json.dumps((asdict(current_graph), rest)).encode() + g = create_graph(contexts[int(self.path.split("/")[-1])]) + ret = json.dumps(({"loc": g.loc, "graphs": list(map(uop_to_json, g.uops)), "diffs": g.diffs, "extra": g.extra}, rest)).encode() else: self.send_response(404) ret = b""