mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
6
test/external/fuzz_viz.py
vendored
6
test/external/fuzz_viz.py
vendored
@@ -19,7 +19,7 @@ if __name__ == "__main__":
|
||||
list(lower_schedule(sched))
|
||||
uret = create_graph(contexts[0])
|
||||
assert uret.loc.split(":")[0] == "schedule.py"
|
||||
assert len(uret.graphs) == len(uret.extra) == 1
|
||||
assert len(uret.uops) == len(uret.extra) == 1
|
||||
assert len(uret.diffs) == 0
|
||||
contexts.clear()
|
||||
|
||||
@@ -33,8 +33,8 @@ if __name__ == "__main__":
|
||||
for ctx in tqdm(contexts):
|
||||
st = time.perf_counter()
|
||||
ret = create_graph(ctx)
|
||||
assert len(ret.graphs) == len(ret.extra)
|
||||
assert len(ret.diffs) == len(ret.graphs)-1, f"found {len(ret.diffs)} diffs but only {len(ret.graphs)-1} graphs"
|
||||
assert len(ret.uops) == len(ret.extra)
|
||||
assert len(ret.diffs) == len(ret.uops)-1, f"found {len(ret.diffs)} diffs but only {len(ret.uops)-1} uops"
|
||||
tms.append(time.perf_counter()-st)
|
||||
|
||||
timings = list(zip(contexts, tms))
|
||||
|
||||
50
test/test_viz.py
Normal file
50
test/test_viz.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# ** setup
|
||||
from typing import List
|
||||
import unittest
|
||||
import os
|
||||
prev_val = os.getenv("TRACK_MATCH_STATS")
|
||||
os.environ["TRACK_MATCH_STATS"] = "2"
|
||||
os.environ["FORWARD_ONLY"] = "1"
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import UOp, contexts
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from viz.serve import create_graph, replace_uop
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def test_ctx_diff(self):
|
||||
a = Tensor.ones(4, 1).contiguous().realize()
|
||||
out = a + a.reshape(1, 4)
|
||||
out.realize()
|
||||
for ctx in contexts:
|
||||
uops = [ctx.sink]
|
||||
for i, (first, rewritten, pat) in enumerate(ctx.rewrites):
|
||||
start = uops[-1]
|
||||
found = [x for x in start.sparents if x.key == first.key]
|
||||
assert found, f"can't find UOp for rewrite_num={i} pattern={pat}"
|
||||
changed: List[UOp] = []
|
||||
new = replace_uop(start, first, rewritten, cache={})
|
||||
if DEBUG >= 4: print_diff(start, new)
|
||||
changed = [x for x in new.sparents if x not in start.sparents]
|
||||
assert len(changed) == len(found), f"{len(changed)} != {len(found)}"
|
||||
assert tuple(changed) == tuple(found), f"{changed} != {found}"
|
||||
|
||||
@unittest.skip("TODO: this graph doesn't change")
|
||||
def test_gemm_diff(self):
|
||||
x = Tensor.empty(64, 64).realize()
|
||||
y = Tensor.empty(64, 64).realize()
|
||||
out = x.matmul(y)
|
||||
contexts.clear()
|
||||
s = out.schedule()
|
||||
list(lower_schedule(s))
|
||||
ctx = contexts[3]
|
||||
ret = create_graph(ctx)
|
||||
for i, (x,y) in enumerate(zip(ret.uops, ret.uops[1:])):
|
||||
if x.key == y.key:
|
||||
raise AssertionError(f"failed to generate the correct diff at rewrite {i}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
if prev_val: os.environ["TRACK_MATCH_STATS"]
|
||||
else: del os.environ["TRACK_MATCH_STATS"]
|
||||
22
viz/serve.py
22
viz/serve.py
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
import pickle, re, os, sys, time, threading, webbrowser, json, difflib, contextlib
|
||||
from tinygrad.helpers import getenv
|
||||
@@ -29,15 +29,15 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UOpRet:
|
||||
loc: str # location that called graph_rewrite
|
||||
graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]] # a seralized version of UOp graphs
|
||||
diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite
|
||||
extra: List[List[str]] # these become code blocks in the UI
|
||||
loc: str # location that called graph_rewrite
|
||||
uops: List[UOp] # snapshot of the entire AST after each rewrite
|
||||
diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite
|
||||
extra: List[List[str]] # these become code blocks in the UI
|
||||
|
||||
def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[UOp, UOp]) -> UOp:
|
||||
def replace_uop(base:UOp, first:UOp, new:UOp, cache:Dict[UOp, UOp]) -> UOp:
|
||||
if (u:=cache.get(base)): return u
|
||||
new_srcs = tuple(new if x.key == prev.key else replace_uop(x, prev, new, cache) for x in base.src)
|
||||
ret = cache[base] = base if new_srcs == base.src else UOp(base.op, base.dtype, new_srcs, base.arg)
|
||||
if base.key == first.key: return new
|
||||
ret = cache[base] = UOp(base.op, base.dtype, tuple(replace_uop(x, first, new, cache) for x in base.src), base.arg)
|
||||
return ret
|
||||
|
||||
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
@@ -51,7 +51,7 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
assert new_sink.op is UOps.SINK
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
return UOpRet(ctx.loc, list(map(uop_to_json, uops)), diffs, extra)
|
||||
return UOpRet(ctx.loc, uops, diffs, extra)
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
@@ -73,8 +73,8 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
|
||||
rest = [x.loc for x in contexts]
|
||||
current_graph = create_graph(contexts[int(self.path.split("/")[-1])])
|
||||
ret = json.dumps((asdict(current_graph), rest)).encode()
|
||||
g = create_graph(contexts[int(self.path.split("/")[-1])])
|
||||
ret = json.dumps(({"loc": g.loc, "graphs": list(map(uop_to_json, g.uops)), "diffs": g.diffs, "extra": g.extra}, rest)).encode()
|
||||
else:
|
||||
self.send_response(404)
|
||||
ret = b""
|
||||
|
||||
Reference in New Issue
Block a user