viz more work [run_process_replay] (#6568)

* infra

* found it

* real work

* bring those back

* cleanup test_viz

* comment that out
This commit is contained in:
qazal
2024-09-17 19:27:09 +08:00
committed by GitHub
parent 455a27dd43
commit 9295bc0189
3 changed files with 34 additions and 41 deletions

View File

@@ -2,7 +2,7 @@
from typing import Tuple
from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.ops import UOp, UOps
from tinygrad.ops import UOp, UOps, KernelInfo
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

View File

@@ -1,50 +1,38 @@
# ** setup
from typing import List
import unittest
import os
prev_val = os.getenv("TRACK_MATCH_STATS")
os.environ["TRACK_MATCH_STATS"] = "2"
os.environ["FORWARD_ONLY"] = "1"
from tinygrad.helpers import DEBUG
from tinygrad.ops import UOp, contexts
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule
from test.external.process_replay.helpers import print_diff
from viz.serve import create_graph, replace_uop
@unittest.skip("TODO: this graph doesn't change")
class TestViz(unittest.TestCase):
def setUp(self) -> None:
os.environ["TRACK_MATCH_STATS"] = "2"
def tearDown(self) -> None:
os.environ["TRACK_MATCH_STATS"] = "0"
def assert_valid_graph(self, t):
from tinygrad.ops import contexts
from tinygrad.engine.realize import lower_schedule
from viz.serve import create_graph
contexts.clear()
s = t.schedule()
list(lower_schedule(s))
for i,ctx in enumerate(contexts):
ret = create_graph(ctx)
for j,(x,y) in enumerate(zip(ret.uops, ret.uops[1:])):
if x.key == y.key:
raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}")
def test_ctx_diff(self):
from tinygrad import Tensor
a = Tensor.ones(4, 1).contiguous().realize()
out = a + a.reshape(1, 4)
out.realize()
for ctx in contexts:
uops = [ctx.sink]
for i, (first, rewritten, pat) in enumerate(ctx.rewrites):
start = uops[-1]
found = [x for x in start.sparents if x.key == first.key]
assert found, f"can't find UOp for rewrite_num={i} pattern={pat}"
changed: List[UOp] = []
new = replace_uop(start, first, rewritten, cache={})
if DEBUG >= 4: print_diff(start, new)
changed = [x for x in new.sparents if x not in start.sparents]
assert len(changed) == len(found), f"{len(changed)} != {len(found)}"
assert tuple(changed) == tuple(found), f"{changed} != {found}"
self.assert_valid_graph(out)
@unittest.skip("TODO: this graph doesn't change")
def test_gemm_diff(self):
from tinygrad import Tensor
x = Tensor.empty(64, 64).realize()
y = Tensor.empty(64, 64).realize()
out = x.matmul(y)
contexts.clear()
s = out.schedule()
list(lower_schedule(s))
ctx = contexts[3]
ret = create_graph(ctx)
for i, (x,y) in enumerate(zip(ret.uops, ret.uops[1:])):
if x.key == y.key:
raise AssertionError(f"failed to generate the correct diff at rewrite {i}")
self.assert_valid_graph(out)
if __name__ == "__main__":
unittest.main()
if prev_val: os.environ["TRACK_MATCH_STATS"]
else: del os.environ["TRACK_MATCH_STATS"]

View File

@@ -34,10 +34,13 @@ class UOpRet:
diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite
extra: List[List[str]] # these become code blocks in the UI
def replace_uop(base:UOp, first:UOp, new:UOp, cache:Dict[UOp, UOp]) -> UOp:
if (u:=cache.get(base)): return u
if base.key == first.key: return new
ret = cache[base] = UOp(base.op, base.dtype, tuple(replace_uop(x, first, new, cache) for x in base.src), base.arg)
def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp:
if (found:=cache.get(base.key)): return found
if base.key == prev.key: ret = new
else:
new_srcs = tuple(replace_uop(x, prev, new, cache) for x in base.src)
ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
cache[base.key] = ret
return ret
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
@@ -45,9 +48,11 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
diffs: List[Tuple[str, List[str]]] = []
extra: List[List[str]] = [[str(ctx.sink)]]
for (first, rewritten, pattern) in ctx.rewrites:
diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
# if the sink was replaced, we have to replace the entire graph, otherwise just replace the parent
new_sink = rewritten if first.op is UOps.SINK else replace_uop(uops[-1], first, rewritten, {})
# TODO: sometimes it hits a ctx and can't find any UOp to replace
#if new_sink is uops[-1]: continue
diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
assert new_sink.op is UOps.SINK
uops.append(new_sink)
extra.append([str(new_sink)])