mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
add cache to fold_divmod_general (#13365)
This commit is contained in:
2
test/external/external_uop_gc.py
vendored
2
test/external/external_uop_gc.py
vendored
@@ -2,6 +2,7 @@ import gc
|
||||
from tinygrad import Tensor, UOp, Device, nn
|
||||
from tinygrad.engine.realize import method_cache, get_program
|
||||
from tinygrad.schedule.indexing import apply_movement_op
|
||||
from tinygrad.uop.divandmod import fold_divmod_general
|
||||
from test.test_tiny import TestTiny
|
||||
|
||||
def uops_allocated(): return sum([isinstance(x, UOp) for x in gc.get_objects()])
|
||||
@@ -69,6 +70,7 @@ if __name__ == "__main__":
|
||||
# these caches will keep uops alive
|
||||
method_cache.clear()
|
||||
apply_movement_op.cache_clear()
|
||||
fold_divmod_general.cache_clear()
|
||||
Tensor._device_seeds.clear()
|
||||
Tensor._device_rng_counters.clear()
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import functools
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
|
||||
|
||||
def fold_divmod_general(d: UOp) -> UOp|None:
|
||||
# NOTE: this cache is only on index UOps and matches the cache in the old ShapeTracker in spirit
|
||||
@functools.cache
|
||||
def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
|
||||
x, y = d.src
|
||||
|
||||
# cancel_divmod: simple cancel div/mod case when the range of the numerator lies within a single denominator interval
|
||||
@@ -40,7 +43,7 @@ def fold_divmod_general(d: UOp) -> UOp|None:
|
||||
return (y2-y1)*(v-v.vmin) + y1
|
||||
|
||||
# fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c
|
||||
if not (x.vmin<0 and CORRECT_DIVMOD_FOLDING):
|
||||
if not (x.vmin<0 and correct_divmod_folding):
|
||||
rems = [min((r:=f%c), r-c, key=abs) for f in factors]
|
||||
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c:
|
||||
if d.op is Ops.MOD: return rem - rem.vmin//c*c
|
||||
@@ -60,7 +63,7 @@ def fold_divmod_general(d: UOp) -> UOp|None:
|
||||
if d.op is Ops.IDIV and x.vmin >= 0:
|
||||
div = min([c] + [abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and abs(f) > 1 and (c%f)==0])
|
||||
# NOTE: this is recursive!
|
||||
if div < c and (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0:
|
||||
if div < c and (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
|
||||
return newxs // (c // div)
|
||||
|
||||
# ** Variable Denominator / Fallback Rules **
|
||||
@@ -101,7 +104,7 @@ div_and_mod_symbolic = PatternMatcher([
|
||||
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
|
||||
|
||||
# ** 2. Slow Rules **
|
||||
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d"), fold_divmod_general),
|
||||
(UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),
|
||||
|
||||
# NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops
|
||||
(UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
|
||||
|
||||
Reference in New Issue
Block a user