From c2ffcf68873aed8e85c9e3003f3c02fd3075a5bc Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 31 Jul 2024 16:24:25 -0400 Subject: [PATCH] remove the wrong mod UOp pattern (#5847) don't think we are hitting it because the stride construction, and it's wrong and not needed --- test/unit/test_symbolic.py | 1 + test/unit/test_uop_symbolic.py | 1 + tinygrad/codegen/uopgraph.py | 1 - 3 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 20819f1c4a..8b3f4f04a8 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -176,6 +176,7 @@ class TestSymbolic(unittest.TestCase): def test_mod_mod(self): self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)") self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0") + self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, "(((a*5)%12)%5)") self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)") def test_mul_mul(self): diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 959905278b..9923c4cc5e 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -236,6 +236,7 @@ class TestSymbolic(unittest.TestCase): def test_mod_mod(self): self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)") self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0") + self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, {"(((a*5)%12)%5)", "(((5*a)%12)%5)"}) self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)") def test_mul_mul(self): diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 7d182669b7..017f245539 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -240,7 +240,6 @@ constant_folder = PatternMatcher([ x*(c0.arg//g)//(c1.arg//g) if c0.arg > 0 and c1.arg > 0 and (g:=math.gcd(c0.arg,c1.arg)) > 1 and g > x2.vmax.arg and x2.vmin.arg >= 0 else None), # mod mod ((NOp.var('x') % NOp.cvar('c0')) % NOp.cvar('c1'), lambda x,c0,c1: x % c0 if 0 < c0.arg < c1.arg else x % c1 if c0.arg % c1.arg == 0 else None), - (((NOp.var('x') * NOp.cvar('c0')) % NOp.cvar('c1')) % NOp.cvar('c0'), lambda x,c0,c1: x.const(0)), # -(x+y) -> -x + -y #(-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)), # (x*c0)+(x*c1) -> x*(c0+c1)