diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 067fd0af79..3f73433226 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -91,6 +91,9 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]: if (factor:=u.const_factor())%c != factor: remainder.append(u.divides(factor)*(factor%c)) something_changed = True + elif u.op is UOps.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is UOps.CONST and s1.arg%c == 0: + remainder.append(u.src[0]) + something_changed = True else: remainder.append(u) if not something_changed: return None return functools.reduce(operator.add, remainder) if remainder else x.const(0) @@ -292,8 +295,6 @@ constant_folder = PatternMatcher([ x-(x.vmin.arg//c.arg)*c.arg if 0 < c.arg and 0 <= x.vmin.arg and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None), # mul mod ((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg == 0 else None), - # mod mod - ((NOp.var('x') % NOp.cvar('c0')) % NOp.cvar('c1'), lambda x,c0,c1: x % c1 if c0.arg % c1.arg == 0 else None), # (x%c)+(x//c)*c = x (NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x), # ** combine terms **