mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -775,6 +775,13 @@ class TestSymbolic(unittest.TestCase):
|
||||
# div nests: y//12 -> a//2, mod nests: y%12 -> (a%2)*6+b, recombine
|
||||
self.helper_test_variable((y//12)*12 + y%12, 0, 43, "(b+a*6)")
|
||||
|
||||
def test_div_mod_recombine_in_additive_sum(self):
  # The matching div/mod pair must be recombined wherever it sits in an additive sum,
  # not only in the two special y+... tree shapes: here the unrelated addend `y`
  # appears once between the pair and once in front of it.
  x = Variable("x", 0, 31)
  y = Variable("y", 0, 5)
  cases = (
    (x//8)*4 + y + (x//2)%4,
    y + (x//8)*4 + (x//2)%4,
  )
  for expr in cases:
    self.helper_test_variable(expr, 0, 20, "(y+x//2)")
|
||||
|
||||
def test_reshape_index_roundtrip(self):
|
||||
# simulate reshape index decompose then recompose — the core pattern this enables
|
||||
# (8,8) decomposed for (16,4): combined=r0*8+r1, div and mod by 4
|
||||
|
||||
@@ -25,6 +25,23 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
|
||||
# pattern for a CONST node whose arg is Invalid (given the pattern name "i")
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
|
||||
# pattern for `cond.where(x, invalid_pat)`: a where() on "cond" whose first argument is any "x" and whose second is the Invalid constant
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
|
||||
|
||||
def fold_add_divmod_recombine(x:UOp) -> UOp|None:
  """Recombine a matching div/mod pair found anywhere among the addends of `x`.

  `x` is flattened with `split_uop(Ops.ADD)`. The addends are searched for a "mod" term
  `base % div` (scale 1) or `(base % div) * scale`, together with a different "div" term
  `quot * (div*scale)`, where `quot` is one of:
    * `base // div`                              (direct:  b%d*m + (b//d)*(d*m) -> b*m)
    * `n // (a*div)` when `base` is `n // a`     (nested: (n//a)%d*m + (n//(a*d))*(d*m) -> (n//a)*m)
  Operands are compared by identity (`is`), and every divisor/multiplier must be a CONST.

  Returns `base*scale` plus all remaining addends for the first pair found, or None when
  no such pair exists.
  """
  addends = list(x.split_uop(Ops.ADD))
  for mod_idx, mod_term in enumerate(addends):
    # recognise the mod side of the pair, either bare or multiplied by a constant
    if mod_term.op is Ops.MOD and mod_term.src[1].op is Ops.CONST:
      base, div, scale = mod_term.src[0], mod_term.src[1].arg, 1
    elif mod_term.op is Ops.MUL and mod_term.src[1].op is Ops.CONST and (inner:=mod_term.src[0]).op is Ops.MOD and inner.src[1].op is Ops.CONST:
      base, div, scale = inner.src[0], inner.src[1].arg, mod_term.src[1].arg
    else:
      continue
    # NOTE(review): the sign of `div` is not checked here -- confirm non-positive divisors cannot reach this point
    for div_idx, div_term in enumerate(addends):
      if div_idx == mod_idx: continue
      # the div side must be `something * (div*scale)`
      if not (div_term.op is Ops.MUL and div_term.src[1].op is Ops.CONST and div_term.src[1].arg == div*scale): continue
      quot = div_term.src[0]
      # both accepted shapes need `quot` to be an integer division by a constant
      if quot.op is not Ops.IDIV or quot.src[1].op is not Ops.CONST: continue
      numer, quot_div = quot.src[0], quot.src[1].arg
      direct = quot_div == div and numer is base
      nested = (not direct and base.op is Ops.IDIV and base.src[1].op is Ops.CONST
                and numer is base.src[0] and quot_div == base.src[1].arg*div)
      if direct or nested:
        # start from the recombined value and add back every addend not consumed by the pair
        total = base*scale
        for k, other in enumerate(addends):
          if k != mod_idx and k != div_idx: total = total + other
        return total
  return None
|
||||
|
||||
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesn't become 0
|
||||
propagate_invalid = PatternMatcher([
|
||||
# propagate invalid, push it past children
|
||||
@@ -55,24 +72,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
|
||||
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
|
||||
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
|
||||
# variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations
|
||||
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"),
|
||||
lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//a*c)*c = x//a. Note if a = 1 it degenerates to the one above
|
||||
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
|
||||
lambda x,a,b,c1,c2,c3: x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
|
||||
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
|
||||
((UPat.var("y")+UPat.var("x")%UPat.cvar("c"))+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda y,x,c: y+x),
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"))+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
|
||||
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"),
|
||||
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
|
||||
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"))+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"),
|
||||
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
# variations of (x%c)+(x//c)*c = x
|
||||
(UPat(Ops.ADD, dtype=dtypes.index, name="x"), fold_add_divmod_recombine),
|
||||
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
||||
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
||||
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
|
||||
|
||||
Reference in New Issue
Block a user