Simplify symbolic.SumNode.__floordiv__ logic (#1220)

This commit is contained in:
chenyu
2023-07-12 15:54:12 -04:00
committed by GitHub
parent a9a1df785f
commit 32be39554c
2 changed files with 18 additions and 19 deletions

View File

@@ -119,6 +119,10 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_some_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
def test_sum_div_some_partial_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(Variable.sum([Variable.num(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
def test_sum_div_no_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")

View File

@@ -186,28 +186,23 @@ class SumNode(RedNode):
def __floordiv__(self, b: int, factoring_allowed=True):
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
factors: List[Node] = []
nofactor_mul: List[Node] = []
nofactor_nonmul: List[Node] = []
fully_divided: List[Node] = []
rest: List[Node] = []
_gcd = b
divisor = 1
for x in self.flat_components:
if x.__class__ is NumNode and x.b%b == 0: factors.append(x)
elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x)
else: nofactor_nonmul.append(x)
if factors: # factor out largest possible gcd
factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors]
if nofactor_mul and not nofactor_nonmul:
gcds = [gcd(x.b, b) for x in nofactor_mul]
if (t := min(gcds)) > 1 and all(x.b%t == 0 for x in nofactor_mul):
nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance
if x.__class__ in (NumNode, MulNode):
if x.b%b == 0: fully_divided.append(x//b)
else:
nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else []
rest.append(x)
_gcd = gcd(_gcd, x.b)
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
else:
nofactor_term = [Node.sum(nofactor_mul+nofactor_nonmul)//b] if nofactor_mul + nofactor_nonmul else []
return Node.sum(factor_term + nofactor_term)
for m in nofactor_mul:
if m.b > 1 and b%m.b == 0: return (self//m.b)//(b//m.b)
return Node.__floordiv__(self, b, factoring_allowed)
rest.append(x)
_gcd = 1
if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
def __mod__(self, b: int):
new_nodes: List[Node] = []