add cache to fold_divmod_general (#13365)

This commit is contained in:
George Hotz
2025-11-19 13:49:18 -08:00
committed by GitHub
parent 3d82b83cec
commit b5309a5043
2 changed files with 9 additions and 4 deletions

View File

@@ -2,6 +2,7 @@ import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op
from tinygrad.uop.divandmod import fold_divmod_general
from test.test_tiny import TestTiny
def uops_allocated():
  """Return how many objects currently tracked by the garbage collector are UOp instances."""
  # one tick per tracked object that is a UOp; everything else is ignored
  return sum(1 for obj in gc.get_objects() if isinstance(obj, UOp))
@@ -69,6 +70,7 @@ if __name__ == "__main__":
# these caches will keep uops alive
method_cache.clear()
apply_movement_op.cache_clear()
fold_divmod_general.cache_clear()
Tensor._device_seeds.clear()
Tensor._device_rng_counters.clear()

View File

@@ -1,8 +1,11 @@
import functools
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
def fold_divmod_general(d: UOp) -> UOp|None:
# NOTE: this cache only sees index-dtype UOps, like the old ShapeTracker cache in spirit. it keeps UOps alive until cache_clear()
@functools.cache
def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
x, y = d.src
# cancel_divmod: simple cancel div/mod case when the range of the numerator lies within a single denominator interval
@@ -40,7 +43,7 @@ def fold_divmod_general(d: UOp) -> UOp|None:
return (y2-y1)*(v-v.vmin) + y1
# fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c
if not (x.vmin<0 and CORRECT_DIVMOD_FOLDING):
if not (x.vmin<0 and correct_divmod_folding):
rems = [min((r:=f%c), r-c, key=abs) for f in factors]
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c:
if d.op is Ops.MOD: return rem - rem.vmin//c*c
@@ -60,7 +63,7 @@ def fold_divmod_general(d: UOp) -> UOp|None:
if d.op is Ops.IDIV and x.vmin >= 0:
div = min([c] + [abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and abs(f) > 1 and (c%f)==0])
# NOTE: this is recursive!
if div < c and (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0:
if div < c and (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
return newxs // (c // div)
# ** Variable Denominator / Fallback Rules **
@@ -101,7 +104,7 @@ div_and_mod_symbolic = PatternMatcher([
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
# ** 2. Slow Rules **
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d"), fold_divmod_general),
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),
# NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops
(UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),