move mod mod pattern into generic mod folding (#6077)

This commit is contained in:
chenyu
2024-08-14 16:24:21 -04:00
committed by GitHub
parent 64563abc90
commit a61cb1ff7c

View File

@@ -91,6 +91,9 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
if (factor:=u.const_factor())%c != factor:
remainder.append(u.divides(factor)*(factor%c))
something_changed = True
elif u.op is UOps.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is UOps.CONST and s1.arg%c == 0:
remainder.append(u.src[0])
something_changed = True
else: remainder.append(u)
if not something_changed: return None
return functools.reduce(operator.add, remainder) if remainder else x.const(0)
@@ -292,8 +295,6 @@ constant_folder = PatternMatcher([
x-(x.vmin.arg//c.arg)*c.arg if 0 < c.arg and 0 <= x.vmin.arg and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),
# mul mod
((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg == 0 else None),
# mod mod
((NOp.var('x') % NOp.cvar('c0')) % NOp.cvar('c1'), lambda x,c0,c1: x % c1 if c0.arg % c1.arg == 0 else None),
# (x%c)+(x//c)*c = x
(NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x),
# ** combine terms **