mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
* runs one metal kernel * conv2d works * ops tests are passing * const folding * all ops work * pre-commit always passes * torch works * working still * fix graph test * tests passing * image almost works * image conv works * most images * fix custom * fix assignment * fix compile enet * clean up comments * fix realize return value * include shapetracker in LB repr * copy should make a copy * reenable method cache * fix lna * dtypes in graph * forward only for IMAGE=2 * simple realize * getting close * fixup new api, it's good except the kernel count * back to 197 kernels * tests should pass * go to a real float * no type_on_cpu * fix the docs * put shapetracker back in its proper place
74 lines
2.0 KiB
Python
74 lines
2.0 KiB
Python
#!/usr/bin/env python
|
|
import unittest
|
|
import networkx as nx # type: ignore
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.graph import G, log_op, prune_graph
|
|
from tinygrad.ops import BinaryOps, LazyOp, MovementOps, ReduceOps
|
|
|
|
def buf(*shp):
  """Return the ``lazydata`` of a freshly created CPU tensor of ones with shape ``shp``."""
  ones = Tensor.ones(*shp, device="CPU")
  return ones.lazydata
|
|
|
|
class TestGraph(unittest.TestCase):
  """Tests for how ``log_op`` and ``prune_graph`` populate the global graph ``G``.

  Each test logs one or more ``LazyOp`` trees into ``G`` (with ``show_graph=True``)
  and compares the result against a hand-built reference ``nx.DiGraph`` up to
  isomorphism, matching on node and edge ``label`` attributes.
  """

  def setUp(self):
    # G is module-level state imported from tinygrad.graph; clear it so each test
    # only sees the ops it logs itself.
    G.clear()

  def helper_compare_graph(self, RG: nx.DiGraph):
    """Fail the current test unless the global graph ``G`` is isomorphic to ``RG``.

    Args:
      RG: reference graph. Nodes are compared on their ``label`` attribute. An
        edge in ``RG`` without a ``label`` matches any edge of ``G`` (used for
        edges created by pruning); an edge with a ``label`` must have the same
        ``label`` in ``G``.
    """
    # nx.is_isomorphic calls the matchers with (G attrs, RG attrs). Using .get()
    # turns a missing label on the G side into a clean mismatch instead of a KeyError.
    def node_match(got, want): return got.get("label") == want.get("label")
    def edge_match(got, want): return "label" not in want or got.get("label") == want["label"]
    # self.fail rather than a bare assert: bare asserts are stripped under `python -O`,
    # which would make this check (and every test relying on it) pass silently.
    if not nx.is_isomorphic(G, RG, node_match=node_match, edge_match=edge_match):
      self.fail("logged graph does not match reference\n"
                f"  got nodes:  {list(G.nodes(data=True))}\n"
                f"  got edges:  {list(G.edges(data=True))}\n"
                f"  want nodes: {list(RG.nodes(data=True))}\n"
                f"  want edges: {list(RG.edges(data=True))}")

  def test_add_graph(self):
    # One elementwise ADD of two (4,4) buffers into a (4,4) output.
    a = buf(4,4)
    b = buf(4,4)
    ast = LazyOp(BinaryOps.ADD, (a,b))
    ret = buf(4,4)

    # Expected: both inputs get an "ADD"-labelled edge into the output node.
    RG = nx.DiGraph()
    RG.add_node(0, label="(4, 4)")
    RG.add_node(1, label="(4, 4)")
    RG.add_node(2, label="(4, 4)")
    RG.add_edge(0, 2, label="ADD")
    RG.add_edge(1, 2, label="ADD")

    log_op(ret, ast, show_graph=True)
    self.helper_compare_graph(RG)

  def test_add_sum_graph(self):
    # RESHAPE a (1,1) buffer to (4,4), ADD it to a (4,4) buffer, then SUM down to (1,1),
    # all logged as a single op tree.
    a = buf(4,4)
    b = buf(1,1)
    op0 = LazyOp(MovementOps.RESHAPE, (b,), (4, 4))
    op1 = LazyOp(BinaryOps.ADD, (a,op0))
    ast = LazyOp(ReduceOps.SUM, (op1,), (1,1))
    ret = buf(1,1)

    # Expected: one edge per input, both labelled with the whole chain ("RES.ADD.SUM").
    # The output label presumably lists the input shapes, then the output shape.
    RG = nx.DiGraph()
    RG.add_node(0, label="(4, 4)")
    RG.add_node(1, label="(1, 1)")
    RG.add_node(2, label="{(4, 4), (1, 1)}\n(1, 1)")
    RG.add_edge(0, 2, label="RES.ADD.SUM")
    RG.add_edge(1, 2, label="RES.ADD.SUM")

    log_op(ret, ast, show_graph=True)
    self.helper_compare_graph(RG)

  def test_add_graph_prune(self):
    # First op: RESHAPE a (1,1) buffer into a (4,4) intermediate.
    a = buf(1,1)
    ast = LazyOp(MovementOps.RESHAPE, (a,), (4, 4))
    ret = buf(4,4)
    log_op(ret, ast, show_graph=True)

    # Second op: ADD that intermediate to another (4,4) buffer.
    b = buf(4,4)
    ast = LazyOp(BinaryOps.ADD, (ret,b))
    ret = buf(4,4)
    log_op(ret, ast, show_graph=True)
    prune_graph()

    # Expected after pruning: only three nodes remain (the reshape's (4,4) output is
    # not among them). The (1,1) input connects straight to the final output through
    # an unlabelled edge, which helper_compare_graph matches against any edge label.
    RG = nx.DiGraph()
    RG.add_node(0, label="(1, 1)")
    RG.add_node(1, label="(4, 4)")
    RG.add_node(2, label="(4, 4)")
    RG.add_edge(0, 2) # edge connecting pruned nodes
    RG.add_edge(1, 2, label="ADD")

    self.helper_compare_graph(RG)
|
|
|
|
# Script entry point: hand control to unittest's CLI runner for this module's tests.
if __name__ == "__main__": unittest.main()
|