mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 13:28:06 -05:00
Cleanup in div_and_mod_folding [pr] (#10178)
* Refactor binary var simplification
* Simplify the congruence logic

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -129,7 +129,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
||||
|
||||
if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
|
||||
|
||||
svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
|
||||
svars, factors, quotients, remainders, gcd, div, const, something_changed = [], [], [], [], c, 1, 0, False
|
||||
for u in split_uop(x, Ops.ADD):
|
||||
if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
|
||||
u = u.src[0]
|
||||
@@ -137,32 +137,24 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
||||
v: UOp = u.divides(f:=u.const_factor())
|
||||
q, r = divmod(f, c)
|
||||
if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
|
||||
offset += r*v.vmin
|
||||
if u.op is Ops.CONST: const += f
|
||||
else: # div is the smallest common divisor of all terms
|
||||
if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
|
||||
gcd = math.gcd(r, gcd)
|
||||
factors.append(f); svars.append(v); quotients.append(q); remainders.append(r) # noqa: E702
|
||||
|
||||
lbound = ubound = offset = offset % c
|
||||
# we can fold if the expression has only one non-constant term and this term can only take on two values
|
||||
if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
|
||||
r = (offset+remainders[0])%c - offset%c
|
||||
offset -= r * v.vmin
|
||||
if which is Ops.MOD: return r*v + offset
|
||||
return (factors[0]-r)//c * v + (const-offset)//c
|
||||
y1 = (factors[0]*v.vmin+const)%c if which is Ops.MOD else (factors[0]*v.vmin+const)//c
|
||||
y2 = (factors[0]*v.vmax+const)%c if which is Ops.MOD else (factors[0]*v.vmax+const)//c
|
||||
return (y2-y1)*(v-v.vmin) + y1
|
||||
|
||||
# a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
|
||||
# within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
|
||||
for (r, v) in zip(remainders, svars):
|
||||
if r > c//2:
|
||||
if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
|
||||
elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
|
||||
offset -= r * v.vmin # determine what the new offset would be
|
||||
else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
|
||||
remainders = [min(r, r-c, key=abs) for r in remainders]
|
||||
if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
|
||||
return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
|
||||
rems = [min(r, r-c, key=abs) for r in remainders]
|
||||
if (rem:=sum(r*v for r,v in zip(rems,svars))+const%c).vmin//c==rem.vmax//c:
|
||||
if which is Ops.MOD: return rem - rem.vmin//c*c
|
||||
return sum((f-r)//c * v for f,r,v in zip(factors,rems,svars)) + (const-const%c+rem.vmin//c*c)//c
|
||||
|
||||
if gcd != 1: something_changed = True
|
||||
if not something_changed:
|
||||
|
||||
Reference in New Issue
Block a user