diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b9846978f8..abb8cefa6c 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -160,6 +160,9 @@ class GroupOp: # BinaryOps that can be flipped Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR} + # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence + Idempotent = {Ops.OR, Ops.AND, Ops.MAX} + # do not preserve f(0) = 0 UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV} @@ -1186,9 +1189,7 @@ symbolic_simple = PatternMatcher([ lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3 (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c), (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x), - (UPat.var("x").maximum(UPat.var("x")), lambda x: x), - ((UPat.var("x") & UPat.var("x")), lambda x: x), - ((UPat.var("x") | UPat.var("x")), lambda x: x), + (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x), (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), # ** zero folding **