From be6b0bce1f000c36bb07457d58fc0960dbaede17 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 10 Mar 2026 22:41:14 -0400 Subject: [PATCH] variations of (x%c)+(x//c)*c (#15212) put those into one function --- test/null/test_uop_symbolic.py | 7 +++++++ tinygrad/uop/symbolic.py | 37 +++++++++++++++++----------------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index 8465623514..369b7542fa 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -775,6 +775,13 @@ class TestSymbolic(unittest.TestCase): # div nests: y//12 -> a//2, mod nests: y%12 -> (a%2)*6+b, recombine self.helper_test_variable((y//12)*12 + y%12, 0, 43, "(b+a*6)") + def test_div_mod_recombine_in_additive_sum(self): + x = Variable("x", 0, 31) + y = Variable("y", 0, 5) + # recombine should work inside larger additive sums, not just in the two special y+... tree shapes + self.helper_test_variable((x//8)*4 + y + (x//2)%4, 0, 20, "(y+x//2)") + self.helper_test_variable(y + (x//8)*4 + (x//2)%4, 0, 20, "(y+x//2)") + def test_reshape_index_roundtrip(self): # simulate reshape index decompose then recompose — the core pattern this enables # (8,8) decomposed for (16,4): combined=r0*8+r1, div and mod by 4 diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index a394cc5521..5ed39f4155 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -25,6 +25,23 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None: invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i") invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat) +def fold_add_divmod_recombine(x:UOp) -> UOp|None: + terms = list(x.split_uop(Ops.ADD)) + for i,u in enumerate(terms): + if u.op is Ops.MOD and u.src[1].op is Ops.CONST: base, div, mul = u.src[0], u.src[1].arg, 1 + elif u.op is Ops.MUL and u.src[1].op is Ops.CONST and (m:=u.src[0]).op is Ops.MOD and m.src[1].op is Ops.CONST: + base, div, mul = m.src[0], m.src[1].arg, u.src[1].arg + else: continue + for j,v in enumerate(terms): + if i == j: continue + if v.op is not Ops.MUL or v.src[1].op is not Ops.CONST or v.src[1].arg != div*mul: continue + q, exact = v.src[0], False + if q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[1].arg == div: exact = q.src[0] is base + if not exact and base.op is Ops.IDIV and base.src[1].op is Ops.CONST: + exact = q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div + if exact: return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base*mul) + return None + # this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0 propagate_invalid = PatternMatcher([ # propagate invalid, push it past children @@ -55,24 +72,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([ (UPat.var("x") // 1, lambda x: x), # x//1 -> x (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed) - # variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations - (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x - ((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"), - lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//a*c)*c = x//a. Note if a = 1 it degenerates to the one above - ((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"), - lambda x,a,b,c1,c2,c3: x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None), - ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"), - lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3 - ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x), - ((UPat.var("y")+UPat.var("x")%UPat.cvar("c"))+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda y,x,c: y+x), - ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"))+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"), - lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None), - ((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"), - lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None), - ((UPat.var("y")+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"), - lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None), - ((UPat.var("y")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"))+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"), - lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None), + # variations of (x%c)+(x//c)*c = x + (UPat(Ops.ADD, dtype=dtypes.index, name="x"), fold_add_divmod_recombine), (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c), (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x), (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),