UOp if min==max folds to CONST (#5828)

* UOp if min==max folds to CONST

* fix test
This commit is contained in:
chenyu
2024-07-30 22:14:22 -04:00
committed by GitHub
parent 4e89d45513
commit c3da458bc3
2 changed files with 9 additions and 8 deletions

View File

@@ -78,17 +78,17 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].arg, 3.0)
def test_consts_go_last(self):
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('a', 0, 1))
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('b', 0, 1))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('c', 0, 1))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('d', 0, 1))
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1))
b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD)]
for out in outs:
sink = graph_rewrite(out, constant_folder)
print(sink)
self.assertEqual(sink.op, UOps.ALU)
self.assertEqual(sink.src[1].op, UOps.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
class TestUOpGraph(TestUOps):
def test_add_constant_fold(self):
@@ -213,13 +213,13 @@ class TestUOpGraph(TestUOps):
self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 1)
def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1))
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
g = UOpGraph([out])
self.assertEqual(len(g.uops), 3)
self.assertEqual(len(g.uops), 5)
out = g.uops[-1]
self.assertEqual(out.op, UOps.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)

View File

@@ -209,6 +209,7 @@ constant_folder = PatternMatcher([
# NOTE: this can be wrong for loaded NaN
(NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
(NOp.var('x') - NOp.var('x'), lambda x: x.const(0)), # x-x -> 0
(UPat(op=UOps.ALU, name='x'), lambda x: x.const(x.vmin.arg) if x.op is not UOps.CONST and x.vmin.arg == x.vmax.arg else None),
# lt folding
(NOp.var('x').lt(NOp.var('y')),
lambda x,y: NOp.const(dtypes.bool, True) if x.vmax.arg < y.vmin.arg else NOp.const(dtypes.bool, False) if x.vmin.arg >= y.vmax.arg else None),