From ab7df42c78ea67b64ceae22a2bc3308af02e1602 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:51:51 -0800 Subject: [PATCH] bring back fold_divmod_general with bugfix and test [pr] (#13369) * Revert "Revert "merge to fold_divmod_general [p] (#13359)"" This reverts commit 05ccc69248e438725f5f16da63d2b24721e9aa77. * Revert "Revert "actually merge to fold_divmod_general [pr] (#13363)"" This reverts commit 90e5752199319458a7c8c61a3b470fa230c26ce8. * Revert "Revert "add cache to fold_divmod_general (#13365)"" This reverts commit 8e17bd67915242d11dcc689c8940765cc7913e0f. * bring back fold_divmod_general with bugfix and test --- test/external/external_uop_gc.py | 2 + tinygrad/uop/divandmod.py | 160 ++++++++++++++----------------- 2 files changed, 73 insertions(+), 89 deletions(-) diff --git a/test/external/external_uop_gc.py b/test/external/external_uop_gc.py index 4327b69a56..f54c24e7c5 100644 --- a/test/external/external_uop_gc.py +++ b/test/external/external_uop_gc.py @@ -2,6 +2,7 @@ import gc from tinygrad import Tensor, UOp, Device, nn from tinygrad.engine.realize import method_cache, get_program from tinygrad.schedule.indexing import apply_movement_op +from tinygrad.uop.divandmod import fold_divmod_general from test.test_tiny import TestTiny def uops_allocated(): return sum([isinstance(x, UOp) for x in gc.get_objects()]) @@ -69,6 +70,7 @@ if __name__ == "__main__": # these caches will keep uops alive method_cache.clear() apply_movement_op.cache_clear() + fold_divmod_general.cache_clear() Tensor._device_seeds.clear() Tensor._device_rng_counters.clear() diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py index 7bb3264d25..5779a1f0d1 100644 --- a/tinygrad/uop/divandmod.py +++ b/tinygrad/uop/divandmod.py @@ -1,100 +1,95 @@ +import functools from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp from tinygrad.dtype import dtypes from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap -def cancel_divmod(d: UOp, x: UOp, y: UOp) -> UOp|None: - # simple cancel div/mod case when the range of the numerator lies within a single denominator interval +# NOTE: this cache is only on index UOps and matches the cache in the old ShapeTracker in spirit +@functools.cache +def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: + x, y = d.src + + # cancel_divmod: simple cancel div/mod case when the range of the numerator lies within a single denominator interval x_min, x_max, y_min, y_max = x.vmin, x.vmax, y.vmin, y.vmax assert isinstance(x_min, int) and isinstance(x_max, int) and isinstance(y_min, int) and isinstance(y_max, int) if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}") if y_min*y_max > 0 and (q:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max): return x - q*y if d.op is Ops.MOD else d.const_like(q) - return None -def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None: - # we can fold if the expression has only one non-constant term and this term can only take on two values - if ((c := y.arg) < 0): return None - x,const = x.pop_const() - terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)]) - if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1: - y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) - y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) - return (y2-y1)*(v-v.vmin) + y1 - return None + # split uops for the rest of the processing + x_peeled, const = x.pop_const() + uops_no_const = list(x_peeled.split_uop(Ops.ADD)) -def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None: - # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c - if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0): return None - x,const = x.pop_const() - terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)]) - # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c - rems = [min((r:=f%c), r-c, key=abs) for f in factors] - if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c!=rem.vmax//c: return None - if d.op is Ops.MOD: return rem - rem.vmin//c*c - return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c + # ** Constant Denominator Rules ** + # these rules strictly require y to be a scalar constant > 0 + if y.op is Ops.CONST and (c := y.arg) > 0: + # remove_nested_mod: remove nested mod in case the inner mod is a multiple of the outer mod, example: (a%4 + b)%2 -> (a+b)%2 + if d.op is Ops.MOD and x.vmin >= 0: + new_xs, changed = [], False + for u in uops_no_const: + if u.op is Ops.MOD and u.src[1].divides(c) is not None: + u = u.src[0] + changed = True + new_xs.append(u) + if changed and (new_x:=(UOp.sum(*new_xs) + const)).vmin >= 0: return new_x % y -def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None: - # x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd) - gcd = UOp.gcd(*x.split_uop(Ops.ADD), y).simplify() - if gcd.op is Ops.CONST and gcd.arg==1: return None - ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd))) - return ret*gcd if d.op is Ops.MOD else ret + # Shared decomposition for folding rules + decomp = [(u.divides(f:=u.const_factor()),f) for u in uops_no_const] + terms, factors = zip(*decomp) -def gcd_with_remainder(d: UOp, x: UOp, y: UOp): - # (gcd*x+r)//(gcd*d) -> (x+(r%d)//gcd)//d + r//(gcd*d) - # (gcd*x+r)%(gcd*d) -> gcd*(x+(r%d)//gcd)%d + r%gcd - # These only work for floordiv (and the corresponding remainder)! Thats why we check the sign of x,y and new_x - if ((c := y.arg) < 0) or x.vmin<0: return None - x_no_const, const = x.pop_const() - gcd = UOp.gcd(*x_no_const.split_uop(Ops.ADD), y).simplify() - assert gcd.op is Ops.CONST - if gcd.arg==1: return None - new_x = unwrap(x_no_const.divide_exact(gcd)).simplify() + (const%c)//gcd - if new_x.vmin<0: return None - ret = new_x.alu(d.op, x.ufix(c//gcd.arg)) - return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c + # fold_binary_numerator: fold if expression has one non-constant term that takes on two values + if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1: + y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) + y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) + return (y2-y1)*(v-v.vmin) + y1 -def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None: - # remove nested mod in case the inner mod is a multiple of the outer mod - # example: (a%4 + b)%2 -> (a+b)%2 - if ((c := y.arg) < 0) or x.vmin<0: return None - new_xs = [] - something_changed = False - for u in x.split_uop(Ops.ADD): - if u.op is Ops.MOD: - if u.src[1].divides(c) is not None: - something_changed = True - u = u.src[0] - new_xs.append(u) - new_x: UOp = UOp.sum(*new_xs) - if something_changed and new_x.vmin>=0: return new_x % y - return None + # fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c + if not (x.vmin<0 and correct_divmod_folding): + rems = [min((r:=f%c), r-c, key=abs) for f in factors] + if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c: + if d.op is Ops.MOD: return rem - rem.vmin//c*c + return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c -def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None: - # we try and nest the div and see if it allows the numerator to be simplified - if ((c := y.arg) < 0): return None - factors = [u.const_factor() for u in x.split_uop(Ops.ADD) if u.op not in (Ops.CONST, Ops.VCONST)] - div = min([y.arg]+[abs(f) for f in factors if abs(f) > 1 and (c%f)==0]) - newxs = fold_divmod_congruence(newx:=(x//div), x, y.const_like(div)) - if newxs is None: newxs = factor_remainder(newx, x, y.const_like(div)) - if div==y.arg or newxs is None or x.vmin<0 or newx.vmin<0: return None - return newxs//(c//div) + # gcd_with_remainder: factor out common gcd from numerator + # Note: this rule uses uops_no_const to exclude the additive constant from the GCD calculation + if x.vmin >= 0: + gcd = UOp.gcd(*uops_no_const, y).simplify() + if gcd.op is Ops.CONST and gcd.arg > 1: + new_x = unwrap(x_peeled.divide_exact(gcd)).simplify() + (const%c)//gcd.arg + if new_x.vmin >= 0: + ret = new_x.alu(d.op, x.ufix(c//gcd.arg)) + return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c -def factor_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None: - # (d*x+y)//d -> x+y//d or (d*x+y)%d - # for mod we go further and take the remainder of all factors to reduce their size - # These only work for floordiv (and the corresponding remainder)! Thats why we check the sign of x,y and new_x + # nest_div_by_smallest_factor: try and nest the div and see if it allows the numerator to be simplified + if d.op is Ops.IDIV and x.vmin >= 0: + div = min([c] + [abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and abs(f) > 1 and (c%f)==0]) + # NOTE: this is recursive! + if div < c and (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0: + return newxs // (c // div) + + # ** Variable Denominator / Fallback Rules ** + # These rules apply to variables OR constants that failed the checks above. + # Reconstruct all uops including const for these checks. + all_uops = uops_no_const + ([x.const_like(const)] if const != 0 else []) + + # divide_by_gcd: x//y -> (x//gcd)//(y//gcd) + gcd = UOp.gcd(*all_uops, y).simplify() + if not (gcd.op is Ops.CONST and gcd.arg==1): + ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd))) + return ret*gcd if d.op is Ops.MOD else ret + + # factor_remainder: (d*x+y)//d -> x+y//d if y.vmin<0 or x.vmin<0: return None quo, rem = [], [] - for u in x.split_uop(Ops.ADD): + for u in all_uops: if (q:=u.divide_exact(y)) is not None: quo.append(q) - # if this is mod and y is a const, we can make the remainder factor sm elif d.op is Ops.MOD and y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c: rem.append(u.divides(c)*(c%y.arg)) - quo.append(u.const_like(0)) # we append this so we can check if something changed + quo.append(u.const_like(0)) else: rem.append(u) + + if not quo: return None new_x = sum(rem)+x.const_like(0) - if len(quo)==0 or new_x.vmin<0: return None + if new_x.vmin<0: return None return new_x%y if d.op is Ops.MOD else new_x//y+sum(quo) div_and_mod_symbolic = PatternMatcher([ @@ -108,21 +103,8 @@ div_and_mod_symbolic = PatternMatcher([ ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None), - # NOTE: if you move this one down you get more uops in test/external/external_benchmark_schedule.py - (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod), - - # ** 2. Slow Constant Denominator Rules (cvar) ** - # Prioritize these because they are mathematically stronger for constants - (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator), - (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence), - (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), gcd_with_remainder), - (UPat(Ops.MOD, dtypes.index, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod), - (UPat(Ops.IDIV, dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor), - - # ** 3. Slow Variable Denominator Rules (var) ** - # These catch cases like x//x or (a*b)//b - (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), divide_by_gcd), - (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), factor_remainder), + # ** 2. Slow Rules ** + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))), # NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),