mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move graph_dedup out of class [run_process_replay] (#5197)
This commit is contained in:
@@ -3,7 +3,7 @@ from test.helpers import TestUOps
|
||||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, UOp, PatternMatcher, graph_rewrite
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, UOp, PatternMatcher, graph_rewrite, graph_dedup
|
||||
#from tinygrad.engine.graph import print_tree
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
@@ -14,6 +14,12 @@ simple_pm = PatternMatcher([
|
||||
])
|
||||
|
||||
class TestGraphRewrite(unittest.TestCase):
|
||||
def test_dedup(self):
|
||||
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
v2 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
nout = graph_dedup(v1+v2)
|
||||
self.assertIs(nout.src[0], nout.src[1])
|
||||
|
||||
def test_simple(self):
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
|
||||
Reference in New Issue
Block a user