variations of (x%c)+(x//c)*c (#15212)

put those into one function
This commit is contained in:
chenyu
2026-03-10 22:41:14 -04:00
committed by GitHub
parent a408d90f4f
commit be6b0bce1f
2 changed files with 26 additions and 18 deletions

View File

@@ -775,6 +775,13 @@ class TestSymbolic(unittest.TestCase):
# div nests: y//12 -> a//2, mod nests: y%12 -> (a%2)*6+b, recombine
self.helper_test_variable((y//12)*12 + y%12, 0, 43, "(b+a*6)")
def test_div_mod_recombine_in_additive_sum(self):
x = Variable("x", 0, 31)
y = Variable("y", 0, 5)
# recombine should work inside larger additive sums, not just in the two special y+... tree shapes
self.helper_test_variable((x//8)*4 + y + (x//2)%4, 0, 20, "(y+x//2)")
self.helper_test_variable(y + (x//8)*4 + (x//2)%4, 0, 20, "(y+x//2)")
def test_reshape_index_roundtrip(self):
# simulate reshape index decompose then recompose — the core pattern this enables
# (8,8) decomposed for (16,4): combined=r0*8+r1, div and mod by 4

View File

@@ -25,6 +25,23 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
def fold_add_divmod_recombine(x:UOp) -> UOp|None:
terms = list(x.split_uop(Ops.ADD))
for i,u in enumerate(terms):
if u.op is Ops.MOD and u.src[1].op is Ops.CONST: base, div, mul = u.src[0], u.src[1].arg, 1
elif u.op is Ops.MUL and u.src[1].op is Ops.CONST and (m:=u.src[0]).op is Ops.MOD and m.src[1].op is Ops.CONST:
base, div, mul = m.src[0], m.src[1].arg, u.src[1].arg
else: continue
for j,v in enumerate(terms):
if i == j: continue
if v.op is not Ops.MUL or v.src[1].op is not Ops.CONST or v.src[1].arg != div*mul: continue
q, exact = v.src[0], False
if q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[1].arg == div: exact = q.src[0] is base
if not exact and base.op is Ops.IDIV and base.src[1].op is Ops.CONST:
exact = q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div
if exact: return functools.reduce(operator.add, (t for k,t in enumerate(terms) if k not in (i,j)), base*mul)
return None
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
propagate_invalid = PatternMatcher([
# propagate invalid, push it past children
@@ -55,24 +72,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
# variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"),
lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//a*c)*c = x//a. Note if a = 1 it degenerates to the one above
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
lambda x,a,b,c1,c2,c3: x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
((UPat.var("y")+UPat.var("x")%UPat.cvar("c"))+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda y,x,c: y+x),
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"))+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"),
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
((UPat.var("y")+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
((UPat.var("y")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"))+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"),
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
# variations of (x%c)+(x//c)*c = x
(UPat(Ops.ADD, dtype=dtypes.index, name="x"), fold_add_divmod_recombine),
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),