mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 05:18:01 -05:00
UOp mul div simplification (#5637)
* UOp mul div simplification * != 0 is fine
This commit is contained in:
@@ -312,11 +312,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_mul_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_mul_div_factor_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_mul_div_factor_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
|
||||
|
||||
|
||||
@@ -246,8 +246,9 @@ constant_folder = PatternMatcher([
|
||||
(UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
|
||||
# (x*c0)+(y*c0) -> (x+y)*c0
|
||||
#((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
|
||||
# (x*c0)//c0 -> x
|
||||
((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
|
||||
# mul div
|
||||
((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c1"),
|
||||
lambda x,c0,c1: x*(c0.arg//gcd)//(c1.arg//gcd) if c1.arg!=0 and (gcd:=math.gcd(c0.arg,c1.arg))> 1 else None),
|
||||
# (x*x2)/x2 -> x
|
||||
((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
|
||||
# (x//c0)//c1 -> x//(c0*c1)
|
||||
|
||||
Reference in New Issue
Block a user