more lt folding using gcd (#6469)

This commit is contained in:
chenyu
2024-09-11 02:09:35 -04:00
committed by GitHub
parent dfe1db1cff
commit d9d1ae7248
2 changed files with 3 additions and 4 deletions

View File

@@ -297,12 +297,10 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6+Variable("b", 0, 6)*6, 8), 0, 1,
"(((a*3)+(b*3))<4)")
@unittest.expectedFailure
def test_lt_sum_factor_rhs_partial(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 4), 0, 1,
"(((a*3)+(b*2)+(c*4))<2)")
@unittest.expectedFailure
def test_lt_sum_factor_rhs_all(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 2), 0, 1,
"(((a*3)+(b*2)+(c*4))<1)")

View File

@@ -151,7 +151,8 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
return newx.src[0].lt(newx.src[1]) if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV else None
if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV: return newx.src[0].lt(newx.src[1])
return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None
def fold_unrolled_divs(divs:UOp, c:UOp):
# div pattern in unrolled arange
@@ -314,7 +315,7 @@ constant_folder = PatternMatcher([
# mul add lt
(((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
# generic lt folding (using div)
# generic lt folding
(NOp.var('x').lt(NOp.cvar('c')),
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
# ** div **