new viz fuzz tests, track multiple contexts (#6913)

* add FUZZ_VIZ option

* add FUZZ_VIZ=1 tests

* use .replace

* rewrites test

* add rewrite_stack

* add FUZZ_VIZ to ops

* what if FUZZ_VIZ was up there

* leave fuzz_viz for now
This commit is contained in:
qazal
2024-10-06 14:58:15 +03:00
committed by GitHub
parent 75d9dcf000
commit 837f9c6832
3 changed files with 22 additions and 26 deletions

View File

@@ -590,6 +590,7 @@ class TrackedRewriteContext:
kernel: Optional[Kernel] = None # the kernel being rewritten
rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat)
contexts: List[TrackedRewriteContext] = []
rewrite_stack: List[TrackedRewriteContext] = []
class TrackedPatternMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
super().__init__(patterns)
@@ -610,7 +611,7 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2 and contexts and isinstance(ret, UOp): contexts[-1].rewrites.append((uop, ret, p))
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1].rewrites.append((uop, ret, p))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
return None
@@ -656,8 +657,10 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
# get Kernel we are rewriting in the context of
frm_walk: Optional[FrameType] = frm
while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back
contexts.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel))
return RewriteContext(pm, ctx).rewrite(sink)
rewrite_stack.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel))
ret = RewriteContext(pm, ctx).rewrite(sink)
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
return ret
# ***** uop type spec *****

View File

@@ -5,22 +5,23 @@ import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib,
from dataclasses import asdict
from urllib.parse import parse_qs, urlparse
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.helpers import getenv, to_function_name
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, UPat, lines
from tinygrad.helpers import getenv, to_function_name, tqdm
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.engine.graph import uops_colors, word_wrap
from viz.spec import GraphRewriteDetails, GraphRewriteMetadata
def reconstruct_graph(sink:UOp, rewrites:List[Tuple[UOp, UOp, UPat]]) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
uops: List[UOp] = [sink]
def reconstruct_graph(ctx:TrackedRewriteContext) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
uops: List[UOp] = [ctx.sink]
diffs: List[List[str]] = []
changed_nodes: List[List[int]] = []
seen_replaces: Dict[UOp, UOp] = {}
for i, (first, rewritten, _) in enumerate(rewrites):
for i, (first, rewritten, upat) in enumerate(ctx.rewrites):
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
seen_replaces[first] = rewritten
new_sink = replace_uop(uops[-1], {**seen_replaces})
# sanity check
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
if new_sink is uops[-1]:
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
# update ret data
changed_nodes.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST])
diffs.append(list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))
@@ -43,7 +44,7 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
if (found:=replaces.get(base)) is not None: return found
replaces[base] = ret = UOp(base.op, base.dtype, tuple(replace_uop(x, replaces) for x in base.src), base.arg)
replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src))
return ret
def load_kernels(contexts) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]:
@@ -79,7 +80,7 @@ class Handler(BaseHTTPRequestHandler):
query = parse_qs(url.query)
if (qkernel:=query.get("kernel")) is not None:
metadata, ctx = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
graphs, diffs, changed_nodes = reconstruct_graph(ctx.sink, ctx.rewrites)
graphs, diffs, changed_nodes = reconstruct_graph(ctx)
ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(ctx.kernel)))).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode()
@@ -104,6 +105,9 @@ if __name__ == "__main__":
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
print("*** unpickled saved rewrites")
kernels = load_kernels(contexts)
if getenv("FUZZ_VIZ"):
for v in tqdm(kernels.values()):
for _,ctx in v: reconstruct_graph(ctx)
print("*** loaded kernels")
server = HTTPServer(('', 8000), Handler)
st = time.perf_counter()

View File

@@ -1,19 +1,17 @@
from typing import List
import unittest
import os, itertools
from viz.spec import GraphRewriteMetadata
os.environ["TRACK_MATCH_STATS"] = "2"
os.environ["PRINT_MATCH_STATS"] = "0"
from extra.models.resnet import ResNet50
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
from tinygrad.helpers import Context, all_same, DEBUG, colored, getenv
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
from test.external.process_replay.helpers import print_diff
from viz.serve import reconstruct_graph, uop_to_json, load_kernels
from viz.spec import GraphRewriteMetadata
def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)}
@@ -25,7 +23,7 @@ class TestViz(unittest.TestCase):
def assert_valid_ctx(self, contexts:List[TrackedRewriteContext]):
assert len(contexts) != 0
for i,ctx in enumerate(contexts):
try: graphs,_,_ = reconstruct_graph(ctx.sink, ctx.rewrites)
try: graphs,_,_ = reconstruct_graph(ctx)
except Exception as e:
print(colored(f"failed to create graph for ctx {i}", "red"))
raise e
@@ -72,7 +70,7 @@ class TestViz(unittest.TestCase):
])
ret = graph_rewrite(sink, pm)
if DEBUG >= 4: print_diff(sink, ret)
graphs,_,_ = reconstruct_graph(contexts[0].sink, contexts[0].rewrites)
graphs,_,_ = reconstruct_graph(contexts[0])
assert graphs[-1].key == ret.key
self.assert_valid_ctx(contexts)
@@ -106,15 +104,6 @@ class TestViz(unittest.TestCase):
self.assert_valid_ctx(contexts)
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts)
@unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules")
def test_fuzz_resnet(self):
mdl = ResNet50()
img = Tensor.empty(64, 3, 224, 224)
out = mdl(img)
sched = out.schedule()
list(lower_schedule(sched))
self.assert_valid_ctx(contexts)
def test_no_ctx(self):
simple_pm = PatternMatcher([(UPat(UOps.CONST), lambda:True)])
simple_pm.rewrite(UOp.const(dtypes.int, 2))