diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 8382e96fd9..f4778deee7 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -1,9 +1,8 @@ from __future__ import annotations from abc import abstractmethod -from collections import defaultdict -import math, functools +import functools +from math import gcd from typing import List, Dict, Callable, Tuple, Type, Union -from tinygrad.helpers import partition, all_same # NOTE: Python has different behavior for negative mod and floor div than c # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod @@ -14,7 +13,7 @@ class Node: max: int def render(self, ops=None, ctx=None) -> str: if ops is None: ops = render_python - assert isinstance(self, (Variable, NumNode)) or self.min != self.max + assert self.__class__ in (Variable, NumNode) or self.min != self.max return ops[type(self)](self, ops, ctx) @functools.cached_property def key(self) -> str: return self.render(ctx="DEBUG") @@ -57,49 +56,42 @@ class Node: @staticmethod def num(num:int) -> Node: return NumNode(num) + @staticmethod + def factorize(nodes:List[Node]): + mul_groups: Dict[Node, int] = {} + for x in nodes: + a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1) + mul_groups[a] = mul_groups.get(a, 0) + b + return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0] + @staticmethod def sum(nodes:List[Node]) -> Node: + nodes = [x for x in nodes if x.max or x.min] + if not nodes: return NumNode(0) + if len(nodes) == 1: return nodes[0] + new_nodes: List[Node] = [] - sum_nodes: List[SumNode] = [] - num_nodes: List[NumNode] = [] - mul_nodes: List[MulNode] = [] + num_node_sum = 0 + + # flatten all sumnodes and gather numnodes for node in nodes: - if isinstance(node, NumNode): - num_nodes.append(node) - elif isinstance(node, MulNode): - mul_nodes.append(node) - elif isinstance(node, SumNode): # expand any sums inside one sum - sum_nodes.append(node) - else: - new_nodes.append(node) + if node.__class__ not in (NumNode, SumNode): new_nodes.append(node) + elif node.__class__ is NumNode: num_node_sum += node.b + elif isinstance(node, SumNode): # mypy wants the isinstance + for sub_node in node.flat_components: + if sub_node.__class__ is NumNode: num_node_sum += sub_node.b + else: new_nodes.append(sub_node) - # expand any sums inside one sum - if sum_nodes: - new_nodes.extend(num_nodes) - new_nodes.extend(mul_nodes) - for x in sum_nodes: new_nodes += x.nodes - return Variable.sum(new_nodes) - - # combine any numbers inside a sum - if num_nodes: - new_nodes.append(NumNode(sum([x.b for x in num_nodes]))) - - # combine any MulNodes that factorize (big hack sticking the MulNode(x, 1) on things) - mul_nodes += [MulNode(x, 1) for x in new_nodes] - mul_groups: Dict[str, Tuple[Node, List[MulNode]]] = defaultdict(lambda: (Node(), [])) - for node in mul_nodes: #NOTE can we somehow avoid rendering here? - key = node.a.render() - mul_groups[key] = (node.a, mul_groups[key][1] + [node]) - mul_nodes = [k * sum(x.b for x in g) for k, g in mul_groups.values()] - new_nodes = [x if not isinstance(x, MulNode) or x.b != 1 else x.a for x in mul_nodes] - - # filter 0s - new_nodes = [x for x in new_nodes if x.min != 0 or x.max != 0] - return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else (new_nodes[0] if len(new_nodes) == 1 else NumNode(0)) + if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes): + new_nodes = Node.factorize(new_nodes) + if num_node_sum: new_nodes.append(NumNode(num_node_sum)) + return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0) @staticmethod def ands(nodes:List[Node]) -> Node: - if any((x.min == 0 and x.max == 0) for x in nodes): return NumNode(0) + if not nodes: return NumNode(1) + if len(nodes) == 1: return nodes[0] + if any([x.min == x.max == 0 for x in nodes]): return NumNode(0) # filter 1s nodes = [x for x in nodes if x.min != x.max] @@ -164,40 +156,46 @@ class RedNode(Node): def __init__(self, nodes:List[Node]): self.nodes = nodes class SumNode(RedNode): - def __mul__(self, b: int): return Variable.sum([x*b for x in self.nodes]) # distribute mul into sum + def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum def __floordiv__(self, b: int, factoring_allowed=True): + if b == 1: return self if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed) - factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0) - nofactor = [] - # ugh, i doubt this is universally right - for x in tmp_nofactor: - if isinstance(x, NumNode): - if (x.b%b) != x.b: - factors.append(Variable.num(x.b - (x.b%b))) # python does floor division - nofactor.append(Variable.num(x.b%b)) - else: - nofactor.append(x) - gcd = [math.gcd(x.b, b) if isinstance(x, (MulNode, NumNode)) else None for x in nofactor] - if len(factors) > 0: - # these don't have to be the same, just having a common factor - if len(gcd) > 0 and all_same(gcd) and gcd[0] is not None and gcd[0] > 1: - nofactor_term = Variable.sum([(x.a * (x.b//gcd[0])) if isinstance(x, MulNode) else Variable.num(x.b//gcd[0]) for x in nofactor])//(b//gcd[0]) - else: - nofactor_term = Variable.sum(nofactor)//b - return Variable.sum([(x.a * (x.b//b)) if isinstance(x, MulNode) else Variable.num(x.b//b) for x in factors] + [nofactor_term]) - else: - muls = [x.b for x in nofactor if isinstance(x, MulNode)] - for m in muls: - if m > 1 and b%m == 0: - return (self//m)//(b//m) - return Node.__floordiv__(self, b, factoring_allowed) + factors: List[Node] = [] + nofactor_mul: List[Node] = [] + nofactor_nonmul: List[Node] = [] + for x in self.flat_components: + if x.__class__ is NumNode and x.b%b == 0: factors.append(x) + elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x) + else: nofactor_nonmul.append(x) + + if factors: # factor out largest possible gcd + factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors] + if nofactor_mul and not nofactor_nonmul: + gcds = [gcd(x.b, b) for x in nofactor_mul] + if (t := min(gcds)) > 1 and all([x.b%t == 0 for x in nofactor_mul]): + nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance + else: + nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else [] + else: + nofactor_term = [Node.sum(nofactor_mul+nofactor_nonmul)//b] if nofactor_mul + nofactor_nonmul else [] + return Node.sum(factor_term + nofactor_term) + for m in nofactor_mul: + if m.b > 1 and b%m.b == 0: return (self//m.b)//(b//m.b) + return Node.__floordiv__(self, b, factoring_allowed) + def __mod__(self, b: int): new_nodes = [] for x in self.nodes: - if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b)) + if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b)) elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b)) else: new_nodes.append(x) - return Node.__mod__(Variable.sum(new_nodes), b) + return Node.__mod__(Node.sum(new_nodes), b) + + @property + def flat_components(self): # recursively expand sumnode components + new_nodes = [] + for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x]) + return new_nodes class AndNode(RedNode): def __mul__(self, b: int): Variable.ands([x*b for x in self.nodes])