mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
viz more work [run_process_replay] (#6568)
* infra * found it * real work * bring those back * cleanup test_viz * comment that out
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from typing import Tuple
|
||||
from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
from tinygrad.ops import UOp, UOps
|
||||
from tinygrad.ops import UOp, UOps, KernelInfo
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
@@ -1,50 +1,38 @@
|
||||
# ** setup
|
||||
from typing import List
|
||||
import unittest
|
||||
import os
|
||||
prev_val = os.getenv("TRACK_MATCH_STATS")
|
||||
os.environ["TRACK_MATCH_STATS"] = "2"
|
||||
os.environ["FORWARD_ONLY"] = "1"
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import UOp, contexts
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from viz.serve import create_graph, replace_uop
|
||||
|
||||
@unittest.skip("TODO: this graph doesn't change")
|
||||
class TestViz(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
os.environ["TRACK_MATCH_STATS"] = "2"
|
||||
def tearDown(self) -> None:
|
||||
os.environ["TRACK_MATCH_STATS"] = "0"
|
||||
|
||||
def assert_valid_graph(self, t):
|
||||
from tinygrad.ops import contexts
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from viz.serve import create_graph
|
||||
contexts.clear()
|
||||
s = t.schedule()
|
||||
list(lower_schedule(s))
|
||||
for i,ctx in enumerate(contexts):
|
||||
ret = create_graph(ctx)
|
||||
for j,(x,y) in enumerate(zip(ret.uops, ret.uops[1:])):
|
||||
if x.key == y.key:
|
||||
raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}")
|
||||
|
||||
def test_ctx_diff(self):
|
||||
from tinygrad import Tensor
|
||||
a = Tensor.ones(4, 1).contiguous().realize()
|
||||
out = a + a.reshape(1, 4)
|
||||
out.realize()
|
||||
for ctx in contexts:
|
||||
uops = [ctx.sink]
|
||||
for i, (first, rewritten, pat) in enumerate(ctx.rewrites):
|
||||
start = uops[-1]
|
||||
found = [x for x in start.sparents if x.key == first.key]
|
||||
assert found, f"can't find UOp for rewrite_num={i} pattern={pat}"
|
||||
changed: List[UOp] = []
|
||||
new = replace_uop(start, first, rewritten, cache={})
|
||||
if DEBUG >= 4: print_diff(start, new)
|
||||
changed = [x for x in new.sparents if x not in start.sparents]
|
||||
assert len(changed) == len(found), f"{len(changed)} != {len(found)}"
|
||||
assert tuple(changed) == tuple(found), f"{changed} != {found}"
|
||||
self.assert_valid_graph(out)
|
||||
|
||||
@unittest.skip("TODO: this graph doesn't change")
|
||||
def test_gemm_diff(self):
|
||||
from tinygrad import Tensor
|
||||
x = Tensor.empty(64, 64).realize()
|
||||
y = Tensor.empty(64, 64).realize()
|
||||
out = x.matmul(y)
|
||||
contexts.clear()
|
||||
s = out.schedule()
|
||||
list(lower_schedule(s))
|
||||
ctx = contexts[3]
|
||||
ret = create_graph(ctx)
|
||||
for i, (x,y) in enumerate(zip(ret.uops, ret.uops[1:])):
|
||||
if x.key == y.key:
|
||||
raise AssertionError(f"failed to generate the correct diff at rewrite {i}")
|
||||
self.assert_valid_graph(out)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
if prev_val: os.environ["TRACK_MATCH_STATS"]
|
||||
else: del os.environ["TRACK_MATCH_STATS"]
|
||||
|
||||
15
viz/serve.py
15
viz/serve.py
@@ -34,10 +34,13 @@ class UOpRet:
|
||||
diffs: List[Tuple[str, List[str]]] # the diffs for each rewrite
|
||||
extra: List[List[str]] # these become code blocks in the UI
|
||||
|
||||
def replace_uop(base:UOp, first:UOp, new:UOp, cache:Dict[UOp, UOp]) -> UOp:
|
||||
if (u:=cache.get(base)): return u
|
||||
if base.key == first.key: return new
|
||||
ret = cache[base] = UOp(base.op, base.dtype, tuple(replace_uop(x, first, new, cache) for x in base.src), base.arg)
|
||||
def replace_uop(base:UOp, prev:UOp, new:UOp, cache:Dict[bytes, UOp]) -> UOp:
|
||||
if (found:=cache.get(base.key)): return found
|
||||
if base.key == prev.key: ret = new
|
||||
else:
|
||||
new_srcs = tuple(replace_uop(x, prev, new, cache) for x in base.src)
|
||||
ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
|
||||
cache[base.key] = ret
|
||||
return ret
|
||||
|
||||
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
@@ -45,9 +48,11 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
diffs: List[Tuple[str, List[str]]] = []
|
||||
extra: List[List[str]] = [[str(ctx.sink)]]
|
||||
for (first, rewritten, pattern) in ctx.rewrites:
|
||||
diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
# if the sink was replaced, we have to replace the entire graph, otherwise just replace the parent
|
||||
new_sink = rewritten if first.op is UOps.SINK else replace_uop(uops[-1], first, rewritten, {})
|
||||
# TODO: sometimes it hits a ctx and can't find any UOp to replace
|
||||
#if new_sink is uops[-1]: continue
|
||||
diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
assert new_sink.op is UOps.SINK
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
|
||||
Reference in New Issue
Block a user