alu(c?t0:f0, c?t1:f1) -> c?alu(t0,t1):alu(f0,f1) (#7900)

* alu(c?t0:f0, c?t1:f1) -> c?alu(t0,t1):alu(f0,f1)

only do if at least one branch is const, so total alu won't increase

* tests and interesting TODO cases
This commit is contained in:
chenyu
2024-12-02 17:19:27 -05:00
committed by GitHub
parent b91fa24387
commit c7bc75e634
2 changed files with 27 additions and 0 deletions

View File

@@ -492,6 +492,30 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
def test_where_combine(self):
cond = Variable("x", 0, 3).lt(2)
a = Variable("a", 0, 3)
b = Variable("b", 0, 3)
aa = cond.where(a, a.ufix(0))
bb = cond.where(b, b.ufix(1))
self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")
# not combining because it increased total ALU
c = Variable("c", 0, 3)
cc = cond.where(c, c+1)
self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))")
# not combining # TODO: can combine if it can further simplify?
ab = cond.where(a, b)
ba = cond.where(b, a)
self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))")
# not combining # TODO: can combine if one is identity element const
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10