mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
alu(c?t0:f0, c?t1:f1) -> c?alu(t0,t1):alu(f0,f1) (#7900)
* alu(c?t0:f0, c?t1:f1) -> c?alu(t0,t1):alu(f0,f1) only do if at least one branch is const, so total alu won't increase * tests and interesting TODO cases
This commit is contained in:
@@ -492,6 +492,30 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
|
||||
|
||||
def test_where_combine(self):
|
||||
cond = Variable("x", 0, 3).lt(2)
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
aa = cond.where(a, a.ufix(0))
|
||||
bb = cond.where(b, b.ufix(1))
|
||||
self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
|
||||
self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
|
||||
self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
|
||||
self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")
|
||||
|
||||
# not combining because it increased total ALU
|
||||
c = Variable("c", 0, 3)
|
||||
cc = cond.where(c, c+1)
|
||||
self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))")
|
||||
|
||||
# not combining # TODO: can combine if it can further simplify?
|
||||
ab = cond.where(a, b)
|
||||
ba = cond.where(b, a)
|
||||
self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))")
|
||||
|
||||
# not combining # TODO: can combine if one is identity element const
|
||||
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
||||
Reference in New Issue
Block a user