x|(x&y) -> x (#7629)

* x|(x&y) -> x

* fix tests
This commit is contained in:
George Hotz
2024-11-11 10:00:18 +08:00
committed by GitHub
parent 94a484542b
commit bbc64bf305
3 changed files with 4 additions and 2 deletions

View File

@@ -94,7 +94,7 @@ class TestLinearizerDumb(unittest.TestCase):
prg = k.to_program()
print(prg.src)
if_uops = [u for u in k.uops if u.op is Ops.IF]
self.assertIn(len(if_uops), {1,3})
self.assertIn(len(if_uops), {1,2,3})
conditions = if_uops[0].src[0].sparents
self.assertLessEqual(len(conditions), 9)

View File

@@ -1043,7 +1043,7 @@ class TestLinearizerFailures(unittest.TestCase):
k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
assert k is not None
ifs = [u for u in k.uops if u.op is Ops.IF]
self.assertEqual(len(ifs), 4)
self.assertEqual(len(ifs), 3)
#for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
self.assertLessEqual(len(ifs[0].src[0].sparents), 17)

View File

@@ -1055,6 +1055,8 @@ symbolic = symbolic_simple+PatternMatcher([
(UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
# group like
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
# ** boolean algebra **
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
# ** combine terms **
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)