mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
155 lines
6.3 KiB
Python
155 lines
6.3 KiB
Python
import unittest
|
|
import os, itertools
|
|
# NOTE(review): assigned before the tinygrad imports below, presumably so the setting is already visible when tinygrad is first loaded - confirm
os.environ["TRACK_MATCH_STATS"] = "2"
|
|
from extra.models.resnet import ResNet50
|
|
from tinygrad import Tensor
|
|
from tinygrad.engine.realize import lower_schedule
|
|
from tinygrad.ops import UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
|
|
from tinygrad.dtype import dtypes, PtrDType
|
|
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
|
|
from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding
|
|
from test.external.process_replay.helpers import print_diff
|
|
from viz.serve import KernelRet, UOpRet, load_kernels, uop_to_json
|
|
|
|
def group_rewrites(kernels:KernelRet):
  """Group the rewrite contexts of `kernels` by their `loc`.

  Returns a dict mapping each distinct `loc` to the list of contexts from `kernels.ctxs` that carry it.
  Keys appear in first-seen order and each list keeps the iteration order of `kernels.ctxs`.
  """
  # accumulate by key instead of itertools.groupby: groupby only merges *consecutive* items, so with
  # unsorted input a later run with the same loc would replace the earlier one in the resulting dict
  grouped = {}
  for ctx in kernels.ctxs.values(): grouped.setdefault(ctx.loc, []).append(ctx)
  return grouped
|
|
|
|
class TestViz(unittest.TestCase):
  """Tests that tracked graph rewrites (`contexts`) can be loaded and rendered by the viz helpers."""

  def tearDown(self) -> None:
    # the tracked contexts are left in place when VIZ is set, otherwise they are reset after every test
    if not getenv("VIZ"): contexts.clear()

  def assert_valid_ctx(self, contexts):
    """Fail unless `contexts` is non-empty and every ctx renders to graphs where consecutive entries differ.

    Any exception raised by `UOpRet.from_ctx` is re-raised after printing the index of the offending ctx.
    """
    self.assertNotEqual(len(contexts), 0)
    for i,ctx in enumerate(contexts):
      try: ret = UOpRet.from_ctx(ctx)
      except Exception:
        print(colored(f"failed to create graph for ctx {i}", "red"))
        raise  # bare raise keeps the original traceback
      # two neighbouring graphs with the same key mean a rewrite step produced no visible change
      for j,(x,y) in enumerate(zip(ret.graphs, ret.graphs[1:])):
        if x.key == y.key: self.fail(f"failed to generate the correct diff at rewrite {j} ctx {i}")

  def assert_valid_graph(self, t):
    """Schedule and lower tensor `t` from a clean slate, then validate every tracked context."""
    contexts.clear()
    s = t.schedule()
    list(lower_schedule(s))
    self.assert_valid_ctx(contexts)

  def test_ctx_diff(self):
    """A (4,1) tensor added to its (1,4) reshape produces valid contexts."""
    a = Tensor.ones(4, 1).contiguous().realize()
    out = a + a.reshape(1, 4)
    self.assert_valid_graph(out)

  def test_ctx_groups(self):
    """Two lowered schedules load as two kernels, each with schedule and uopgraph contexts."""
    contexts.clear()
    schedule1 = Tensor.zeros(4, 1).contiguous().exp().schedule()
    schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
    list(lower_schedule(schedule1))
    list(lower_schedule(schedule2))
    ret = load_kernels(contexts)
    self.assertEqual(len(ret), 2)
    for kernel in ret:
      files = [x.loc[0] for x in kernel.ctxs.values()]
      self.assertTrue(any("schedule" in f for f in files), "kernel has no ctx located in schedule")
      self.assertTrue(any("uopgraph" in f for f in files), "kernel has no ctx located in uopgraph")

  def test_gemm_diff(self):
    """A 64x64 matmul produces valid contexts."""
    x = Tensor.empty(64, 64).realize()
    y = Tensor.empty(64, 64).realize()
    out = x.matmul(y)
    self.assert_valid_graph(out)

  def test_removed_node(self):
    """The last graph rendered for a rewrite matches the UOp returned by graph_rewrite."""
    # contexts[0] below must belong to this test even when tearDown did not reset the list (VIZ set)
    contexts.clear()
    vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple((UOp.const(dtypes.int, 1),)*4))
    gep = UOp(UOps.GEP, dtypes.int, (vec,), (0,))
    sink = UOp(UOps.STORE, dtypes.void, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0), gep)).sink()
    pm = PatternMatcher([
      (UPat(UOps.VECTORIZE, name="root", src=(UPat(UOps.CONST, name="const"),), allow_any_len=True, location="test"),
       lambda root,const: UOp.const_like(root, const.arg) if all_same(root.src) else None),
      (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),), location="test"), lambda root,x: root.const_like(x.arg))
    ])
    ret = graph_rewrite(sink, pm)
    if DEBUG >= 4: print_diff(sink, ret)
    g = UOpRet.from_ctx(contexts[0])
    self.assertEqual(g.graphs[-1].key, ret.key)
    self.assert_valid_ctx(contexts)

  def test_devectorize_viz(self):
    """Rewriting a hand-built vectorized sink yields valid contexts that are all located in this file."""
    # the location check at the end only holds if no ctx from an earlier test is still in the list
    contexts.clear()
    sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=1, upcasted=1, dont_use_locals=False), src=(
      UOp(UOps.STORE, dtypes.void, arg=None, src=(
        UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
        UOp(UOps.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
          UOp(UOps.VECTORIZE, dtypes.int.vec(4), arg=None, src=(
            x4:=UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
              x5:=UOp(UOps.SPECIAL, dtypes.int, arg=('lidx0', 4), src=()),
              UOp(UOps.CONST, dtypes.int, arg=4, src=()),)),
             x4,
             x4,
             x4,)),
          x7:=UOp(UOps.VCONST, dtypes.int.vec(4), arg=(0, 1, 2, 3), src=()),)),
        UOp(UOps.ALU, dtypes.float.vec(4), arg=BinaryOps.ADD, src=(
          UOp(UOps.VECTORIZE, dtypes.float.vec(4), arg=None, src=(
            x10:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
              x11:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
              x5,)),
             x10,
             x10,
             x10,)),
          UOp(UOps.LOAD, dtypes.float.vec(4), arg=None, src=(
            x11,
            x7,)),)),)),))
    pm = constant_folder+(devectorize+float4_folding)
    new_sink = graph_rewrite(sink, pm)
    if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
    self.assert_valid_ctx(contexts)
    this_file = os.path.basename(__file__)
    self.assertTrue(all(os.path.basename(ctx.loc[0]) == this_file for ctx in contexts))

  @unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules")
  def test_fuzz_resnet(self):
    """Every context tracked while scheduling and lowering a ResNet50 forward pass is valid."""
    contexts.clear()
    mdl = ResNet50()
    img = Tensor.empty(64, 3, 224, 224)
    out = mdl(img)
    sched = out.schedule()
    list(lower_schedule(sched))
    self.assert_valid_ctx(contexts)

  def test_no_ctx(self):
    """Calling PatternMatcher.rewrite directly does not add a tracked context."""
    # start from an empty list so the final length check only reflects the rewrite below
    contexts.clear()
    simple_pm = PatternMatcher([(UPat(UOps.CONST), lambda:True)])
    simple_pm.rewrite(UOp.const(dtypes.int, 2))
    self.assertEqual(len(contexts), 0)

  def test_dedup_ast(self):
    """Scheduling two tensors built from the same expression loads as a single kernel."""
    contexts.clear()
    a = Tensor.empty(4, 4).contiguous().realize()+2
    b = Tensor.empty(4, 4).contiguous().realize()+2
    Tensor.schedule(a, b)
    kernels = load_kernels(contexts)
    self.assertEqual(len(kernels), 1)
    # NOTE(review): k is a ctx `loc`, which other tests in this class index as loc[0] to get the file path.
    # If loc is a tuple, `"schedule.py" in k` is an exact element match (not a substring match on the path)
    # and this filter may never select anything - confirm whether `"schedule.py" in k[0]` was intended.
    self.assertTrue(all(len(v) == 1 for k,v in group_rewrites(kernels[0]).items() if "schedule.py" in k))

  def test_no_dedup_different_opts(self):
    """Lowering one schedule under NOOPT=1 and NOOPT=0 loads as two distinct kernels."""
    contexts.clear()
    a = Tensor.empty(4, 4)+Tensor.empty(4, 4)
    s = a.schedule()
    with Context(NOOPT=1): list(lower_schedule(s.copy()))
    with Context(NOOPT=0): list(lower_schedule(s.copy()))
    kernels = load_kernels(contexts)
    self.assertEqual(len(kernels), 2)
    self.assertTrue(all(len(v) == 1 for v in group_rewrites(kernels[0]).values()))
    # NOTE(review): every group returned by group_rewrites holds at least one ctx, so `len(v) == 0` can only
    # pass when the filter selects nothing; and if loc is a tuple, `"schedule.py" in k` is an exact element
    # match that may never select anything - confirm whether `"schedule.py" in k[0]` was intended.
    self.assertTrue(all(len(v) == 0 for k,v in group_rewrites(kernels[1]).items() if "schedule.py" in k))

  def test_fold_const_nodes(self):
    """In the JSON for an `empty+2` AST no entry starts with CONST and exactly one entry mentions CONST."""
    a = Tensor.empty(4, 4)+2
    contexts.clear()
    sink = a.schedule()[-1].ast
    ret = uop_to_json(sink)
    # debug dump only on request, like the print_diff calls in the other tests
    if DEBUG >= 4:
      for v in ret.values(): print(v)
    self.assertFalse(any(v[0].startswith("CONST") for v in ret.values()))
    self.assertEqual(len([x for x in ret.values() if "CONST" in x[0]]), 1)

  def test_no_fold_single_const(self):
    """A lone CONST UOp still serializes to exactly one entry."""
    node = UOp(UOps.CONST, dtypes.float, (), 1.0)
    ret = uop_to_json(node)
    self.assertEqual(len(ret), 1)
|
|
|
|
# run the unittest runner when this file is executed as a script
if __name__ == "__main__": unittest.main()
|