diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index a6423cba01..447cd31815 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -389,7 +389,7 @@ class TestSymbolicSymbolicOps(unittest.TestCase): unfactored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), False) factored_expr = Node.__floordiv__(expr_before_div, NumNode(-16), True) self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)") - self.assertEqual(factored_expr.render(), "((((1019+gid)//4)+lid)//4)") + self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)") def test_mod_node_max(self): i = Variable("i", 1, 128) @@ -467,9 +467,8 @@ class TestSymbolicRealWorld(unittest.TestCase): idx = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3) print(idx.render()) - # TODO: 13,151,129,600 is out of int32 range. - # assert idx.render() == \ - # "((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)" + # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range. + assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)" if __name__ == '__main__': unittest.main() diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index d396e7700f..4973fa8cbd 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -257,12 +257,15 @@ class SumNode(RedNode): divisor = 1 for x in self.flat_components: if x.__class__ in (NumNode, MulNode): - if x.b%b == 0: fully_divided.append(x//b) + if x.b % b == 0: fully_divided.append(x // b) else: + if x.__class__ is NumNode and (div := x.b // b): + fully_divided.append(NumNode(div)) + x = NumNode(x.b - b * div) rest.append(x) if isinstance(x.b, int): _gcd = gcd(_gcd, x.b) - if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b + if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b else: _gcd = 1 else: