From 49c6dab74b28a426a188d5e4268f963ed9ec0764 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:16:58 +0100 Subject: [PATCH] Add pattern for div mod recombine with gcd (#8061) Co-authored-by: chenyu --- test/unit/test_uop_symbolic.py | 2 ++ tinygrad/ops.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 168600f840..cdbf7ec1a4 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -444,6 +444,8 @@ class TestSymbolic(unittest.TestCase): def test_div_mod_recombine_with_gcd(self): b = Variable("b", 0, 100) + exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18 + self.helper_test_variable(exp, 2, 1602, "((b*16)+2)") with self.assertRaises(AssertionError): self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)") diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 02daa2d4fa..c555ae1e4b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1046,6 +1046,8 @@ symbolic_simple = PatternMatcher([ ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed) (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x + ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"), + lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3 (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c), (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x), (UPat.var("x").maximum(UPat.var("x")), lambda x: x),