mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
new viz fuzz tests, track multiple contexts (#6913)
* add FUZZ_VIZ option * add FUZZ_VIZ=1 tests * use .replace * rewrites test * add rewrite_stack * add FUZZ_VIZ to ops * what if FUZZ_VIZ was up there * leave fuzz_viz for now
This commit is contained in:
@@ -590,6 +590,7 @@ class TrackedRewriteContext:
|
||||
kernel: Optional[Kernel] = None # the kernel being rewritten
|
||||
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
|
||||
contexts: List[TrackedRewriteContext] = []
|
||||
rewrite_stack: List[TrackedRewriteContext] = []
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
super().__init__(patterns)
|
||||
@@ -610,7 +611,7 @@ class TrackedPatternMatcher(PatternMatcher):
|
||||
match_stats[p][2] += (et:=time.perf_counter()-st)
|
||||
match_stats[p][3] += et
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 2 and contexts and isinstance(ret, UOp): contexts[-1].rewrites.append((uop, ret, p))
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1].rewrites.append((uop, ret, p))
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
return None
|
||||
@@ -656,8 +657,10 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
# get Kernel we are rewriting in the context of
|
||||
frm_walk: Optional[FrameType] = frm
|
||||
while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back
|
||||
contexts.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel))
|
||||
return RewriteContext(pm, ctx).rewrite(sink)
|
||||
rewrite_stack.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel))
|
||||
ret = RewriteContext(pm, ctx).rewrite(sink)
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
|
||||
return ret
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
|
||||
20
viz/serve.py
20
viz/serve.py
@@ -5,22 +5,23 @@ import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib,
|
||||
from dataclasses import asdict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad.helpers import getenv, to_function_name
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, UPat, lines
|
||||
from tinygrad.helpers import getenv, to_function_name, tqdm
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
|
||||
from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
from viz.spec import GraphRewriteDetails, GraphRewriteMetadata
|
||||
|
||||
def reconstruct_graph(sink:UOp, rewrites:List[Tuple[UOp, UOp, UPat]]) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
|
||||
uops: List[UOp] = [sink]
|
||||
def reconstruct_graph(ctx:TrackedRewriteContext) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
|
||||
uops: List[UOp] = [ctx.sink]
|
||||
diffs: List[List[str]] = []
|
||||
changed_nodes: List[List[int]] = []
|
||||
seen_replaces: Dict[UOp, UOp] = {}
|
||||
for i, (first, rewritten, _) in enumerate(rewrites):
|
||||
for i, (first, rewritten, upat) in enumerate(ctx.rewrites):
|
||||
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
|
||||
seen_replaces[first] = rewritten
|
||||
new_sink = replace_uop(uops[-1], {**seen_replaces})
|
||||
# sanity check
|
||||
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
|
||||
if new_sink is uops[-1]:
|
||||
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
|
||||
# update ret data
|
||||
changed_nodes.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST])
|
||||
diffs.append(list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))
|
||||
@@ -43,7 +44,7 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
|
||||
def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
|
||||
if (found:=replaces.get(base)) is not None: return found
|
||||
replaces[base] = ret = UOp(base.op, base.dtype, tuple(replace_uop(x, replaces) for x in base.src), base.arg)
|
||||
replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src))
|
||||
return ret
|
||||
|
||||
def load_kernels(contexts) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]:
|
||||
@@ -79,7 +80,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
query = parse_qs(url.query)
|
||||
if (qkernel:=query.get("kernel")) is not None:
|
||||
metadata, ctx = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
|
||||
graphs, diffs, changed_nodes = reconstruct_graph(ctx.sink, ctx.rewrites)
|
||||
graphs, diffs, changed_nodes = reconstruct_graph(ctx)
|
||||
ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
|
||||
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(ctx.kernel)))).encode()
|
||||
else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode()
|
||||
@@ -104,6 +105,9 @@ if __name__ == "__main__":
|
||||
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
|
||||
print("*** unpickled saved rewrites")
|
||||
kernels = load_kernels(contexts)
|
||||
if getenv("FUZZ_VIZ"):
|
||||
for v in tqdm(kernels.values()):
|
||||
for _,ctx in v: reconstruct_graph(ctx)
|
||||
print("*** loaded kernels")
|
||||
server = HTTPServer(('', 8000), Handler)
|
||||
st = time.perf_counter()
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
from typing import List
|
||||
import unittest
|
||||
import os, itertools
|
||||
|
||||
from viz.spec import GraphRewriteMetadata
|
||||
os.environ["TRACK_MATCH_STATS"] = "2"
|
||||
os.environ["PRINT_MATCH_STATS"] = "0"
|
||||
from extra.models.resnet import ResNet50
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
|
||||
from tinygrad.helpers import Context, all_same, DEBUG, colored, getenv
|
||||
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from viz.serve import reconstruct_graph, uop_to_json, load_kernels
|
||||
from viz.spec import GraphRewriteMetadata
|
||||
|
||||
def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)}
|
||||
|
||||
@@ -25,7 +23,7 @@ class TestViz(unittest.TestCase):
|
||||
def assert_valid_ctx(self, contexts:List[TrackedRewriteContext]):
|
||||
assert len(contexts) != 0
|
||||
for i,ctx in enumerate(contexts):
|
||||
try: graphs,_,_ = reconstruct_graph(ctx.sink, ctx.rewrites)
|
||||
try: graphs,_,_ = reconstruct_graph(ctx)
|
||||
except Exception as e:
|
||||
print(colored(f"failed to create graph for ctx {i}", "red"))
|
||||
raise e
|
||||
@@ -72,7 +70,7 @@ class TestViz(unittest.TestCase):
|
||||
])
|
||||
ret = graph_rewrite(sink, pm)
|
||||
if DEBUG >= 4: print_diff(sink, ret)
|
||||
graphs,_,_ = reconstruct_graph(contexts[0].sink, contexts[0].rewrites)
|
||||
graphs,_,_ = reconstruct_graph(contexts[0])
|
||||
assert graphs[-1].key == ret.key
|
||||
self.assert_valid_ctx(contexts)
|
||||
|
||||
@@ -106,15 +104,6 @@ class TestViz(unittest.TestCase):
|
||||
self.assert_valid_ctx(contexts)
|
||||
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts)
|
||||
|
||||
@unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules")
|
||||
def test_fuzz_resnet(self):
|
||||
mdl = ResNet50()
|
||||
img = Tensor.empty(64, 3, 224, 224)
|
||||
out = mdl(img)
|
||||
sched = out.schedule()
|
||||
list(lower_schedule(sched))
|
||||
self.assert_valid_ctx(contexts)
|
||||
|
||||
def test_no_ctx(self):
|
||||
simple_pm = PatternMatcher([(UPat(UOps.CONST), lambda:True)])
|
||||
simple_pm.rewrite(UOp.const(dtypes.int, 2))
|
||||
|
||||
Reference in New Issue
Block a user