simpler commutative flipping condition (#7377)

`x.src[1].tuplize < x.src[0].tuplize` implies `x.src[0] is not x.src[1]`

also renamed cc -> op
This commit is contained in:
chenyu
2024-10-29 13:51:24 -04:00
committed by GitHub
parent d3c192b056
commit 07ad6d20ed

View File

@@ -1035,8 +1035,7 @@ symbolic = PatternMatcher([
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
# ** COMMUTATIVE flipping **
*[(UPat(UOps.ALU, arg=cc, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[0] is not x.src[1] \
and x.src[1].tuplize < x.src[0].tuplize else None) for cc in COMMUTATIVE],
*[(UPat(UOps.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),