diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index bd60f130e3..a471c9734c 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -119,6 +119,10 @@ class TestSymbolic(unittest.TestCase): def test_sum_div_some_factor(self): self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))") + def test_sum_div_some_partial_factor(self): + self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") + self.helper_test_variable(Variable.sum([Variable.num(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") + def test_sum_div_no_factor(self): self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 2b9b459313..49027d2f31 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -186,28 +186,23 @@ class SumNode(RedNode): def __floordiv__(self, b: int, factoring_allowed=True): if b == 1: return self if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed) - factors: List[Node] = [] - nofactor_mul: List[Node] = [] - nofactor_nonmul: List[Node] = [] + fully_divided: List[Node] = [] + rest: List[Node] = [] + _gcd = b + divisor = 1 for x in self.flat_components: - if x.__class__ is NumNode and x.b%b == 0: factors.append(x) - elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x) - else: nofactor_nonmul.append(x) - - if factors: # factor out largest possible gcd - factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors] - if nofactor_mul and not nofactor_nonmul: - gcds = [gcd(x.b, b) for x in nofactor_mul] - if (t := min(gcds)) > 1 and all(x.b%t == 0 for x in nofactor_mul): - nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance + if x.__class__ in (NumNode, MulNode): + if x.b%b == 0: fully_divided.append(x//b) else: - nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else [] + rest.append(x) + _gcd = gcd(_gcd, x.b) + if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b else: - nofactor_term = [Node.sum(nofactor_mul+nofactor_nonmul)//b] if nofactor_mul + nofactor_nonmul else [] - return Node.sum(factor_term + nofactor_term) - for m in nofactor_mul: - if m.b > 1 and b%m.b == 0: return (self//m.b)//(b//m.b) - return Node.__floordiv__(self, b, factoring_allowed) + rest.append(x) + _gcd = 1 + if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd) + if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor) + return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b) def __mod__(self, b: int): new_nodes: List[Node] = []