mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-17 18:11:49 -05:00
UOp if min==max folds to CONST (#5828)
* UOp if min==max folds to CONST * fix test
This commit is contained in:
@@ -78,17 +78,17 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
def test_consts_go_last(self):
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('a', 0, 1))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('b', 0, 1))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('c', 0, 1))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('d', 0, 1))
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD)]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, constant_folder)
|
||||
print(sink)
|
||||
self.assertEqual(sink.op, UOps.ALU)
|
||||
self.assertEqual(sink.src[1].op, UOps.CONST)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
|
||||
|
||||
class TestUOpGraph(TestUOps):
|
||||
def test_add_constant_fold(self):
|
||||
@@ -213,13 +213,13 @@ class TestUOpGraph(TestUOps):
|
||||
self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 1)
|
||||
|
||||
def test_depth_2_const_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1))
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
||||
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len(g.uops), 3)
|
||||
self.assertEqual(len(g.uops), 5)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.op, UOps.ALU)
|
||||
self.assertEqual(out.arg, BinaryOps.ADD)
|
||||
|
||||
Reference in New Issue
Block a user