mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
simplify c0*x<c1 for negative int c0,c1 (#6431)
* simplify c0*x<c1 for negative int c0,c1 * fine if rhs is zero
This commit is contained in:
@@ -198,6 +198,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_mul_lt(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a*-1)<-2)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a*-1)<-3)")
|
||||
|
||||
|
||||
@@ -254,8 +254,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_mul_lt(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-4))<(-11))"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-4))<(-12))"})
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, {"((a*-1)<0)", "((a*(-1))<0)"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-1))<(-2))"})
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))"})
|
||||
|
||||
def test_div_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
|
||||
|
||||
@@ -292,6 +292,9 @@ constant_folder = PatternMatcher([
|
||||
# c0*x<c1 for positive int c0,c1
|
||||
((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
|
||||
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
|
||||
# c0*x<c1 for negative int c0 and non-positive c1
|
||||
((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
|
||||
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if dtypes.is_int(x.dtype) and c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
|
||||
# mul add lt
|
||||
(((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
|
||||
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
|
||||
|
||||
Reference in New Issue
Block a user