start viz unittests (#6550)

* test_viz

* more tests
This commit is contained in:
qazal
2024-09-17 18:58:23 +08:00
committed by GitHub
parent 67a03e72bb
commit 455a27dd43
3 changed files with 64 additions and 14 deletions

View File

@@ -19,7 +19,7 @@ if __name__ == "__main__":
list(lower_schedule(sched))
uret = create_graph(contexts[0])
assert uret.loc.split(":")[0] == "schedule.py"
assert len(uret.graphs) == len(uret.extra) == 1
assert len(uret.uops) == len(uret.extra) == 1
assert len(uret.diffs) == 0
contexts.clear()
@@ -33,8 +33,8 @@ if __name__ == "__main__":
for ctx in tqdm(contexts):
st = time.perf_counter()
ret = create_graph(ctx)
assert len(ret.graphs) == len(ret.extra)
assert len(ret.diffs) == len(ret.graphs)-1, f"found {len(ret.diffs)} diffs but only {len(ret.graphs)-1} graphs"
assert len(ret.uops) == len(ret.extra)
assert len(ret.diffs) == len(ret.uops)-1, f"found {len(ret.diffs)} diffs but only {len(ret.uops)-1} uops"
tms.append(time.perf_counter()-st)
timings = list(zip(contexts, tms))

50
test/test_viz.py Normal file
View File

@@ -0,0 +1,50 @@
# ** setup
from typing import List
import unittest
import os
prev_val = os.getenv("TRACK_MATCH_STATS")
os.environ["TRACK_MATCH_STATS"] = "2"
os.environ["FORWARD_ONLY"] = "1"
from tinygrad.helpers import DEBUG
from tinygrad.ops import UOp, contexts
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule
from test.external.process_replay.helpers import print_diff
from viz.serve import create_graph, replace_uop
class TestViz(unittest.TestCase):
def test_ctx_diff(self):
a = Tensor.ones(4, 1).contiguous().realize()
out = a + a.reshape(1, 4)
out.realize()
for ctx in contexts:
uops = [ctx.sink]
for i, (first, rewritten, pat) in enumerate(ctx.rewrites):
start = uops[-1]
found = [x for x in start.sparents if x.key == first.key]
assert found, f"can't find UOp for rewrite_num={i} pattern={pat}"
changed: List[UOp] = []
new = replace_uop(start, first, rewritten, cache={})
if DEBUG >= 4: print_diff(start, new)
changed = [x for x in new.sparents if x not in start.sparents]
assert len(changed) == len(found), f"{len(changed)} != {len(found)}"
assert tuple(changed) == tuple(found), f"{changed} != {found}"
@unittest.skip("TODO: this graph doesn't change")
def test_gemm_diff(self):
x = Tensor.empty(64, 64).realize()
y = Tensor.empty(64, 64).realize()
out = x.matmul(y)
contexts.clear()
s = out.schedule()
list(lower_schedule(s))
ctx = contexts[3]
ret = create_graph(ctx)
for i, (x,y) in enumerate(zip(ret.uops, ret.uops[1:])):
if x.key == y.key:
raise AssertionError(f"failed to generate the correct diff at rewrite {i}")
if __name__ == "__main__":
unittest.main()
if prev_val: os.environ["TRACK_MATCH_STATS"]
else: del os.environ["TRACK_MATCH_STATS"]

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from typing import Dict, List, Tuple
import pickle, re, os, sys, time, threading, webbrowser, json, difflib, contextlib
from tinygrad.helpers import getenv
@@ -29,15 +29,15 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
@dataclass(frozen=True)
class UOpRet:
loc: str # location that called graph_rewrite
graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]] # a seralized version of UOp graphs
diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite
extra: List[List[str]] # these become code blocks in the UI
loc: str # location that called graph_rewrite
uops: List[UOp] # snapshot of the entire AST after each rewrite
diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite
extra: List[List[str]] # these become code blocks in the UI
def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[UOp, UOp]) -> UOp:
def replace_uop(base:UOp, first:UOp, new:UOp, cache:Dict[UOp, UOp]) -> UOp:
if (u:=cache.get(base)): return u
new_srcs = tuple(new if x.key == prev.key else replace_uop(x, prev, new, cache) for x in base.src)
ret = cache[base] = base if new_srcs == base.src else UOp(base.op, base.dtype, new_srcs, base.arg)
if base.key == first.key: return new
ret = cache[base] = UOp(base.op, base.dtype, tuple(replace_uop(x, first, new, cache) for x in base.src), base.arg)
return ret
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
@@ -51,7 +51,7 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
assert new_sink.op is UOps.SINK
uops.append(new_sink)
extra.append([str(new_sink)])
return UOpRet(ctx.loc, list(map(uop_to_json, uops)), diffs, extra)
return UOpRet(ctx.loc, uops, diffs, extra)
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
@@ -73,8 +73,8 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers()
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
rest = [x.loc for x in contexts]
current_graph = create_graph(contexts[int(self.path.split("/")[-1])])
ret = json.dumps((asdict(current_graph), rest)).encode()
g = create_graph(contexts[int(self.path.split("/")[-1])])
ret = json.dumps(({"loc": g.loc, "graphs": list(map(uop_to_json, g.uops)), "diffs": g.diffs, "extra": g.extra}, rest)).encode()
else:
self.send_response(404)
ret = b""