From bb8cf948f2308613790c81bca872700cc78258cb Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 6 Nov 2025 18:53:28 -0500 Subject: [PATCH] variation of (x%c)+(x//c)*c = x (#13135) when x is in the form of y//b, the idiv term might have combined --- test/unit/test_uop_symbolic.py | 4 ++++ tinygrad/uop/symbolic.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 087d848690..b958b62a45 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -643,6 +643,10 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(lidx+(gidx//4)*8+2*(gidx%4), 0, 372, "(lidx+(gidx*2))") self.helper_test_variable(lidx+2*(gidx%4)+(gidx//4)*8, 0, 372, "(lidx+(gidx*2))") + def test_div_mod_recombine_partial(self): + gidx = Variable("gidx", 0, 15) + self.helper_test_variable((gidx//2)%4+(gidx//8)*4, 0, 7, "gidx//2") + def test_div_mod_recombine_folded_mod(self): a = Variable("a", 0, 2) b = Variable("b", 0, 100) diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 13a7156211..7d445a2d00 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -48,8 +48,10 @@ symbolic_simple = propagate_invalid + PatternMatcher([ (UPat.var("x") // 1, lambda x: x), # x//1 -> x (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed) - # 4 variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations + # variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x + ((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"), + lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//a*c)*c = x//a. Note if a = 1 it degenerates to the one above ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"), lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3 ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),