factor out partial in SumNode div int (#3841)

* factor out partial in SumNode div int

* div not rem

* space
This commit is contained in:
chenyu
2024-03-20 16:34:33 -04:00
committed by GitHub
parent 8cb5215885
commit 519336cfea
2 changed files with 8 additions and 6 deletions

View File

@@ -389,7 +389,7 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False)
factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True)
self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
self.assertEqual(factored_expr.render(), "((((1019+gid)//4)+lid)//4)")
self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
def test_mod_node_max(self):
i = Variable("i", 1, 128)
@@ -467,9 +467,8 @@ class TestSymbolicRealWorld(unittest.TestCase):
idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
print(idx.render())
# TODO: 13,151,129,600 is out of int32 range.
# assert idx.render() == \
# "((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)"
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
if __name__ == '__main__':
unittest.main()

View File

@@ -257,12 +257,15 @@ class SumNode(RedNode):
divisor = 1
for x in self.flat_components:
if x.__class__ in (NumNode, MulNode):
if x.b%b == 0: fully_divided.append(x//b)
if x.b % b == 0: fully_divided.append(x // b)
else:
if x.__class__ is NumNode and (div := x.b // b):
fully_divided.append(NumNode(div))
x = NumNode(x.b - b * div)
rest.append(x)
if isinstance(x.b, int):
_gcd = gcd(_gcd, x.b)
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
else:
_gcd = 1
else: