mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
more lt folding using gcd (#6469)
This commit is contained in:
@@ -297,12 +297,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6+Variable("b", 0, 6)*6, 8), 0, 1,
|
||||
"(((a*3)+(b*3))<4)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_lt_sum_factor_rhs_partial(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 4), 0, 1,
|
||||
"(((a*3)+(b*2)+(c*4))<2)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_lt_sum_factor_rhs_all(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8, 2), 0, 1,
|
||||
"(((a*3)+(b*2)+(c*4))<1)")
|
||||
|
||||
@@ -151,7 +151,8 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
|
||||
|
||||
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
return newx.src[0].lt(newx.src[1]) if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV else None
|
||||
if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV: return newx.src[0].lt(newx.src[1])
|
||||
return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None
|
||||
|
||||
def fold_unrolled_divs(divs:UOp, c:UOp):
|
||||
# div pattern in unrolled arange
|
||||
@@ -314,7 +315,7 @@ constant_folder = PatternMatcher([
|
||||
# mul add lt
|
||||
(((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
|
||||
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
|
||||
# generic lt folding (using div)
|
||||
# generic lt folding
|
||||
(NOp.var('x').lt(NOp.cvar('c')),
|
||||
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
|
||||
# ** div **
|
||||
|
||||
Reference in New Issue
Block a user