diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py index c354d484b1..b49ca66dd7 100644 --- a/tinygrad/codegen/symbolic.py +++ b/tinygrad/codegen/symbolic.py @@ -129,7 +129,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None - svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False + svars, factors, quotients, remainders, gcd, div, const, something_changed = [], [], [], [], c, 1, 0, False for u in split_uop(x, Ops.ADD): if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0: u = u.src[0] @@ -137,32 +137,24 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split v: UOp = u.divides(f:=u.const_factor()) q, r = divmod(f, c) if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True - offset += r*v.vmin if u.op is Ops.CONST: const += f else: # div is the smallest common divisor of all terms if f > 1 and c % f == 0 and (div == 1 or div > f): div = f gcd = math.gcd(r, gcd) factors.append(f); svars.append(v); quotients.append(q); remainders.append(r) # noqa: E702 - lbound = ubound = offset = offset % c # we can fold if the expression has only one non-constant term and this term can only take on two values if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1: - r = (offset+remainders[0])%c - offset%c - offset -= r * v.vmin - if which is Ops.MOD: return r*v + offset - return (factors[0]-r)//c * v + (const-offset)//c + y1 = (factors[0]*v.vmin+const)%c if which is Ops.MOD else (factors[0]*v.vmin+const)//c + y2 = (factors[0]*v.vmax+const)%c if which is Ops.MOD else (factors[0]*v.vmax+const)//c + return (y2-y1)*(v-v.vmin) + y1 # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c - for (r, v) in zip(remainders, svars): - if r > c//2: - if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break - elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break - offset -= r * v.vmin # determine what the new offset would be - else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div - remainders = [min(r, r-c, key=abs) for r in remainders] - if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset)) - return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c)) + rems = [min(r, r-c, key=abs) for r in remainders] + if (rem:=sum(r*v for r,v in zip(rems,svars))+const%c).vmin//c==rem.vmax//c: + if which is Ops.MOD: return rem - rem.vmin//c*c + return sum((f-r)//c * v for f,r,v in zip(factors,rems,svars)) + (const-const%c+rem.vmin//c*c)//c if gcd != 1: something_changed = True if not something_changed: