mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
factor out partial in SumNode div int (#3841)
* factor out partial in SumNode div int * div not rem * space
This commit is contained in:
@@ -389,7 +389,7 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False)
|
||||
factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True)
|
||||
self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
|
||||
self.assertEqual(factored_expr.render(), "((((1019+gid)//4)+lid)//4)")
|
||||
self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
|
||||
|
||||
def test_mod_node_max(self):
|
||||
i = Variable("i", 1, 128)
|
||||
@@ -467,9 +467,8 @@ class TestSymbolicRealWorld(unittest.TestCase):
|
||||
|
||||
idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
|
||||
print(idx.render())
|
||||
# TODO: 13,151,129,600 is out of int32 range.
|
||||
# assert idx.render() == \
|
||||
# "((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)"
|
||||
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
|
||||
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -257,12 +257,15 @@ class SumNode(RedNode):
|
||||
divisor = 1
|
||||
for x in self.flat_components:
|
||||
if x.__class__ in (NumNode, MulNode):
|
||||
if x.b%b == 0: fully_divided.append(x//b)
|
||||
if x.b % b == 0: fully_divided.append(x // b)
|
||||
else:
|
||||
if x.__class__ is NumNode and (div := x.b // b):
|
||||
fully_divided.append(NumNode(div))
|
||||
x = NumNode(x.b - b * div)
|
||||
rest.append(x)
|
||||
if isinstance(x.b, int):
|
||||
_gcd = gcd(_gcd, x.b)
|
||||
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
|
||||
if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
|
||||
else:
|
||||
_gcd = 1
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user