diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 8c46440d6d..b6e81bfb52 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -312,11 +312,9 @@ class TestSymbolic(unittest.TestCase): def test_mul_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") - @unittest.expectedFailure def test_mul_div_factor_mul(self): self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)") - @unittest.expectedFailure def test_mul_div_factor_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)") diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 0eb5fe03f1..f1deb59e5b 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -246,8 +246,9 @@ constant_folder = PatternMatcher([ (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])), # (x*c0)+(y*c0) -> (x+y)*c0 #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)), - # (x*c0)//c0 -> x - ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None), + # mul div + ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c1"), + lambda x,c0,c1: x*(c0.arg//gcd)//(c1.arg//gcd) if c1.arg!=0 and (gcd:=math.gcd(c0.arg,c1.arg))> 1 else None), # (x*x2)/x2 -> x ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x), # (x//c0)//c1 -> x//(c0*c1)