From 97b116bb1d67d950a34ad75f91d64ca46833ba82 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 22 Jul 2024 16:14:12 -0400 Subject: [PATCH] UOp mul div simplification (#5637) * UOp mul div simplification * != 0 is fine --- test/unit/test_uop_symbolic.py | 2 -- tinygrad/codegen/uopgraph.py | 5 +++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 8c46440d6d..b6e81bfb52 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -312,11 +312,9 @@ class TestSymbolic(unittest.TestCase): def test_mul_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") - @unittest.expectedFailure def test_mul_div_factor_mul(self): self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)") - @unittest.expectedFailure def test_mul_div_factor_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)") diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 0eb5fe03f1..f1deb59e5b 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -246,8 +246,9 @@ constant_folder = PatternMatcher([ (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])), # (x*c0)+(y*c0) -> (x+y)*c0 #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)), - # (x*c0)//c0 -> x - ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None), + # mul div + ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c1"), + lambda x,c0,c1: x*(c0.arg//gcd)//(c1.arg//gcd) if c1.arg!=0 and (gcd:=math.gcd(c0.arg,c1.arg))> 1 else None), # (x*x2)/x2 -> x ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x), # (x//c0)//c1 -> x//(c0*c1)