mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
alu(c?t0:f0, c?t1:f1) -> c?alu(t0,t1):alu(f0,f1) (#7900)
* alu(c?t0:f0, c?t1:f1) -> c?alu(t0,t1):alu(f0,f1) only do if at least one branch is const, so total alu won't increase * tests and interesting TODO cases
This commit is contained in:
@@ -492,6 +492,30 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
|
||||
|
||||
def test_where_combine(self):
|
||||
cond = Variable("x", 0, 3).lt(2)
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
aa = cond.where(a, a.ufix(0))
|
||||
bb = cond.where(b, b.ufix(1))
|
||||
self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
|
||||
self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
|
||||
self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
|
||||
self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")
|
||||
|
||||
# not combining because it increased total ALU
|
||||
c = Variable("c", 0, 3)
|
||||
cc = cond.where(c, c+1)
|
||||
self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))")
|
||||
|
||||
# not combining # TODO: can combine if it can further simplify?
|
||||
ab = cond.where(a, b)
|
||||
ba = cond.where(b, a)
|
||||
self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))")
|
||||
|
||||
# not combining # TODO: can combine if one is identity element const
|
||||
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
||||
@@ -1132,6 +1132,9 @@ symbolic = symbolic_simple+PatternMatcher([
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# alu of two where with same conds can combine, only do if true branch or false branch is const
|
||||
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
|
||||
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
|
||||
# ALU min==max -> CONST (slow!)
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
|
||||
Reference in New Issue
Block a user