mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
More conditions for (x//c1+a)//c2 -> (x+a*c1)//(c1*c2) (#10834)
* add rule and test * typo --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -271,6 +271,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
a = Variable("a", 0, 124)
|
||||
self.helper_test_variable(((a-10)//2+10)//2, 2, 33, "((((a+-10)//2)+10)//2)")
|
||||
|
||||
def test_div_const_div_wrong_sign_divisor(self):
|
||||
a = Variable("a", 0, 124)
|
||||
self.helper_test_variable(((a+10)//-2+10)//-4, -1, 14, "(((((a//2)*-1)+5)//4)*-1)")
|
||||
|
||||
def test_neg_mod(self):
|
||||
a = Variable("a", 0, 124)
|
||||
self.helper_test_variable((-a)%4, -3, 0, "((a%4)*-1)")
|
||||
|
||||
@@ -283,7 +283,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
# ** div **
|
||||
# div folding
|
||||
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
|
||||
if (x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
|
||||
if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
|
||||
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax <=0 else None),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
|
||||
|
||||
Reference in New Issue
Block a user