Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
move test_viz to test/, prereq for tinygrad/viz [pr] (#6972)
.github/workflows/test.yml (vendored, 2 changed lines)
@@ -345,8 +345,6 @@ jobs:
      run: PYTHONPATH="." METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 FUZZ_N=24 FUZZ_MAX_SIZE=10000000 python test/external/fuzz_linearizer.py
    - name: Fuzz Test models schedule
      run: FUZZ_SCHEDULE=1 FUZZ_SCHEDULE_MAX_PATHS=5 python -m pytest test/models/test_train.py test/models/test_end2end.py
    - name: Run VIZ=1 tests
      run: PYTHONPATH="." TRACK_MATCH_STATS=2 METAL=1 python3 -m pytest viz/test_viz.py
    - name: Run TRANSCENDENTAL math
      run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
    - name: Run process replay tests
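The CI step above runs the viz tests with match tracking enabled. As a rough local equivalent, assuming the suite now lives at test/test_viz.py after this move and that pytest is available in the current environment, a sketch like the following reproduces that setup from the repo root (METAL=1 is what CI adds on its macOS runners; it is omitted here):

import os, subprocess, sys

# mirror the CI step's environment: repo root on PYTHONPATH, match tracking at level 2
env = dict(os.environ, PYTHONPATH=".", TRACK_MATCH_STATS="2")
# run the relocated suite; add METAL="1" to env on a machine with a Metal backend
subprocess.run([sys.executable, "-m", "pytest", "test/test_viz.py"], env=env, check=True)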
test/test_viz.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from typing import List
import unittest
import itertools
from tinygrad import Tensor, dtypes
from tinygrad.helpers import Context, getenv
from tinygrad.engine.realize import lower_schedule
from viz.serve import GraphRewriteMetadata, get_metadata, _uop_to_json
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UPat, UOps, UOp, graph_rewrite, contexts, track_rewrites

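# bucket consecutive GraphRewriteMetadata entries by source location; itertools.groupby only merges adjacent entries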
def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)}

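# each test enables match tracking, drives scheduling/lowering or a raw graph_rewrite, then checks that viz.serve can turn the captured contexts into metadata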
class TestViz(unittest.TestCase):
  def setUp(self) -> None:
    contexts.clear()
    self.prev_val = TRACK_MATCH_STATS.value
    TRACK_MATCH_STATS.value = 2
  def tearDown(self) -> None:
    from tinygrad.ops import TRACK_MATCH_STATS, contexts
    if not getenv("VIZ"): contexts.clear()
    TRACK_MATCH_STATS.value = self.prev_val

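  # helpers: assert that at least one rewrite context was captured and that get_metadata can consume it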
  def assert_valid_ctx(self):
    from tinygrad.ops import contexts
    assert len(contexts) != 0
    return get_metadata(contexts)

  def assert_valid_graph(self, t):
    s = t.schedule()
    list(lower_schedule(s))
    self.assert_valid_ctx()

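  # end-to-end: schedule and lower real Tensor graphs, then make sure viz metadata can be built from the tracked contexts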
  def test_ctx_diff(self):
    a = Tensor.ones(4, 1).contiguous().realize()
    out = a + a.reshape(1, 4)
    self.assert_valid_graph(out)

  def test_ctx_groups(self):
    schedule1 = Tensor.zeros(4, 1).contiguous().exp().schedule()
    schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
    list(lower_schedule(schedule1))
    list(lower_schedule(schedule2))
    ret = self.assert_valid_ctx()
    assert len(ret) == 3
    assert all(len([x for _,_,x in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
    assert all(len([x for _,_,x in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])

  def test_gemm_diff(self):
    x = Tensor.empty(64, 64).realize()
    y = Tensor.empty(64, 64).realize()
    out = x.matmul(y)
    self.assert_valid_graph(out)

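  # no explicit key argument: the rewritten UOp itself should show up as the tracked key, with only the matching upat recorded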
  def test_track_no_ctx(self):
    @track_rewrites
    def simplify_and_verify(u:UOp):
      simplify = TrackedPatternMatcher([(UPat.var("x")*1, lambda x:x)])
      verify = TrackedPatternMatcher([(UPat(UOps.CONST), lambda:True)])
      verify.rewrite(graph_rewrite(u, simplify))
    u = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))*1
    simplify_and_verify(u)
    ret = self.assert_valid_ctx()
    self.assertEqual(len(ret), 1)
    key, ctx, metadata = ret[0][0]
    self.assertIs(key, u)
    self.assertIs(ctx.sink, u)
    self.assertEqual(len(metadata.upats), 1)

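  # explicit string keys: each call is tracked under the key it was given; x*1 matches the pattern once, x*2 not at all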
  def test_track_rewrites(self):
    simple = TrackedPatternMatcher([(UPat.var("x")*1, lambda x:x)])
    @track_rewrites
    def do_rewrite(key:str, x:UOp): return graph_rewrite(x, simple)
    ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    do_rewrite("uop_0", ld*1)
    do_rewrite("uop_1", ld*2)
    ret = self.assert_valid_ctx()
    self.assertEqual(len(ret), 1)
    key, _, m = ret[0][0]
    self.assertEqual(key, "uop_0")
    self.assertEqual(len(m.upats), 1)
    key, _, m = ret[0][1]
    self.assertEqual(key, "uop_1")
    self.assertEqual(len(m.upats), 0)

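  # two identical ASTs dedup to a single kernel, so each schedule.py rewrite location should be recorded only once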
  def test_dedup_ast(self):
    a = Tensor.empty(4, 4).contiguous().realize()+2
    b = Tensor.empty(4, 4).contiguous().realize()+2
    Tensor.schedule(a, b)
    kernels = self.assert_valid_ctx()
    self.assertEqual(len(kernels), 1)
    rewrites = [x[2] for x in kernels[0]]
    assert all(len(v) == 1 for k,v in group_rewrites(rewrites).items() if "schedule.py" in k)

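  # the same schedule lowered with NOOPT=1 and NOOPT=0 must produce two separately tracked kernels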
  def test_no_dedup_different_opts(self):
    a = Tensor.empty(4, 4)+Tensor.empty(4, 4)
    s = a.schedule()
    with Context(NOOPT=1): list(lower_schedule(s.copy()))
    with Context(NOOPT=0): list(lower_schedule(s.copy()))
    kernels = self.assert_valid_ctx()[1:]
    self.assertEqual(len(kernels), 2)
    rewrites = [x[2] for x in kernels[0]]
    assert all(len(v) == 1 for _,v in group_rewrites(rewrites).items())

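  # _uop_to_json should fold CONST nodes into their consumer's label instead of emitting standalone CONST graph nodes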
  def test_fold_const_nodes(self):
    a = Tensor.empty(4, 4)+2
    sink = a.schedule()[-1].ast
    ret = _uop_to_json(sink)
    assert not any(v[0].startswith("CONST") for v in ret.values())
    assert len([x for x in ret.values() if "CONST" in x[0]]) == 1

if __name__ == "__main__":
  unittest.main()
viz/test_viz.py (deleted, 142 lines)
@@ -1,142 +0,0 @@
from typing import Any, List, Tuple
import unittest
import os, itertools
os.environ["TRACK_MATCH_STATS"] = "2"
os.environ["PRINT_MATCH_STATS"] = "0"
from tinygrad import Tensor, dtypes
from tinygrad.engine.realize import lower_schedule
from tinygrad.dtype import PtrDType
from tinygrad.helpers import Context, all_same, getenv
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps, track_rewrites
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
from viz.serve import GraphRewriteMetadata, get_metadata, get_details, _uop_to_json

def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)}

class TestViz(unittest.TestCase):
  def tearDown(self) -> None:
    from tinygrad.ops import contexts
    if not getenv("VIZ"): contexts.clear()

  def assert_valid_ctx(self, contexts:List[Tuple[Any,List[TrackedRewriteContext]]]):
    assert len(contexts) != 0
    get_metadata(contexts)

  def assert_valid_graph(self, t):
    contexts.clear()
    s = t.schedule()
    list(lower_schedule(s))
    self.assert_valid_ctx(contexts)

  def test_ctx_diff(self):
    a = Tensor.ones(4, 1).contiguous().realize()
    out = a + a.reshape(1, 4)
    self.assert_valid_graph(out)

  def test_ctx_groups(self):
    contexts.clear()
    schedule1 = Tensor.zeros(4, 1).contiguous().exp().schedule()
    schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
    list(lower_schedule(schedule1))
    list(lower_schedule(schedule2))
    with Context(TRACK_MATCH_STATS=0): ret = get_metadata(contexts)
    assert len(ret) == 3
    assert all(len([x for _,_,x in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
    assert all(len([x for _,_,x in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])

  def test_gemm_diff(self):
    x = Tensor.empty(64, 64).realize()
    y = Tensor.empty(64, 64).realize()
    out = x.matmul(y)
    self.assert_valid_graph(out)

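  # after the rewrite removes the VECTORIZE/GEP nodes, the last graph served by get_details must equal _uop_to_json of the rewrite result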
  def test_removed_node(self):
    vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple((UOp.const(dtypes.int, 1),)*4))
    gep = UOp(UOps.GEP, dtypes.int, (vec,), (0,))
    sink = UOp(UOps.STORE, dtypes.void, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0), gep)).sink()
    pm = PatternMatcher([
      (UPat(UOps.VECTORIZE, name="root", src=(UPat(UOps.CONST, name="const"),), allow_any_len=True, location="test"),
       lambda root,const: UOp.const_like(root, const.arg) if all_same(root.src) else None),
      (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),), location="test"), lambda root,x: root.const_like(x.arg))
    ])
    @track_rewrites
    def f(k): return graph_rewrite(sink, pm)
    ret = f("test_rewrite")
    self.assert_valid_ctx(contexts)
    args = get_metadata(contexts)[0][0]
    g = get_details(*args)
    assert g.graphs[-1] == _uop_to_json(ret)

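  # run the real sym + devectorize + float4_folding passes over a hand-built vectorized store and check every tracked context points back at this file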
  def test_devectorize_viz(self):
    sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=1, upcasted=1, dont_use_locals=False), src=(
      UOp(UOps.STORE, dtypes.void, arg=None, src=(
        UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
        UOp(UOps.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
          UOp(UOps.VECTORIZE, dtypes.int.vec(4), arg=None, src=(
            x4:=UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
              x5:=UOp(UOps.SPECIAL, dtypes.int, arg=('lidx0', 4), src=()),
              UOp(UOps.CONST, dtypes.int, arg=4, src=()),)),
            x4,
            x4,
            x4,)),
          x7:=UOp(UOps.VCONST, dtypes.int.vec(4), arg=(0, 1, 2, 3), src=()),)),
        UOp(UOps.ALU, dtypes.float.vec(4), arg=BinaryOps.ADD, src=(
          UOp(UOps.VECTORIZE, dtypes.float.vec(4), arg=None, src=(
            x10:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
              x11:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
              x5,)),
            x10,
            x10,
            x10,)),
          UOp(UOps.LOAD, dtypes.float.vec(4), arg=None, src=(
            x11,
            x7,)),)),)),))
    pm = sym+(devectorize+float4_folding)
    @track_rewrites
    def f(k): return graph_rewrite(sink, pm)
    f("test_rewrite")
    self.assert_valid_ctx(contexts)
    assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)

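  # a plain (untracked) PatternMatcher must not leave anything behind in contexts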
  def test_no_ctx(self):
    simple_pm = PatternMatcher([(UPat(UOps.CONST), lambda:True)])
    simple_pm.rewrite(UOp.const(dtypes.int, 2))
    self.assertEqual(len(contexts), 0)

  def test_dedup_ast(self):
    contexts.clear()
    a = Tensor.empty(4, 4).contiguous().realize()+2
    b = Tensor.empty(4, 4).contiguous().realize()+2
    Tensor.schedule(a, b)
    with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)
    self.assertEqual(len(kernels), 1)
    rewrites = [x[2] for x in kernels[0]]
    assert all(len(v) == 1 for k,v in group_rewrites(rewrites).items() if "schedule.py" in k)

  def test_no_dedup_different_opts(self):
    contexts.clear()
    a = Tensor.empty(4, 4)+Tensor.empty(4, 4)
    s = a.schedule()
    with Context(NOOPT=1): list(lower_schedule(s.copy()))
    with Context(NOOPT=0): list(lower_schedule(s.copy()))
    with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)[1:]
    self.assertEqual(len(kernels), 2)
    rewrites = [x[2] for x in kernels[0]]
    assert all(len(v) == 1 for _,v in group_rewrites(rewrites).items())

  def test_fold_const_nodes(self):
    a = Tensor.empty(4, 4)+2
    contexts.clear()
    sink = a.schedule()[-1].ast
    ret = _uop_to_json(sink)
    assert not any(v[0].startswith("CONST") for v in ret.values())
    assert len([x for x in ret.values() if "CONST" in x[0]]) == 1

  @unittest.skip("VIZ for a single CONST isn't supported anymore")
  def test_no_fold_single_const(self):
    node = UOp(UOps.CONST, dtypes.float, (), 1.0)
    ret = _uop_to_json(node, base=node)
    assert len(ret) == 1

if __name__ == "__main__":
  unittest.main()