UOp if min==max folds to CONST (#5828)

* UOp if min==max folds to CONST

* fix test
This commit is contained in:
chenyu
2024-07-30 22:14:22 -04:00
committed by GitHub
parent 4e89d45513
commit c3da458bc3
2 changed files with 9 additions and 8 deletions

View File

@@ -78,17 +78,17 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].arg, 3.0)
def test_consts_go_last(self):
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('a', 0, 1))
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('b', 0, 1))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('c', 0, 1))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('d', 0, 1))
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1))
b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD)]
for out in outs:
sink = graph_rewrite(out, constant_folder)
print(sink)
self.assertEqual(sink.op, UOps.ALU)
self.assertEqual(sink.src[1].op, UOps.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
class TestUOpGraph(TestUOps):
def test_add_constant_fold(self):
@@ -213,13 +213,13 @@ class TestUOpGraph(TestUOps):
self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 1)
def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1))
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
g = UOpGraph([out])
self.assertEqual(len(g.uops), 3)
self.assertEqual(len(g.uops), 5)
out = g.uops[-1]
self.assertEqual(out.op, UOps.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)