symbolic.py: faster Node.sum, faster SumNode.div (#1014)

* refactor: replace isinstance with class check where possible

* refactor: faster partition

* fix; flake8

* feat: rework node.sum, correct list typing

* fix: typo

* feat: refactor sum

* fix: pylint

* refactor: simpler sum and factorize

* feat; clean up sumnode div, all cpu tests pass

* feat: simplify floordiv, cache factorization

* don't factor numnodes at all

* python 3.8 functools does not yet have @cache

* fix: restore assert

* refactor, fix failing tests

* fix: address review comments

* feat: rework, add specialization, remove cache

* fix: remove specialization

* feat: no tuple conversion, faster loop

---------

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
Roelof van Dijk
2023-06-26 18:47:17 +02:00
committed by GitHub
parent 52b7105f87
commit c604ef4beb

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
from abc import abstractmethod
from collections import defaultdict
import math, functools
import functools
from math import gcd
from typing import List, Dict, Callable, Tuple, Type, Union
from tinygrad.helpers import partition, all_same
# NOTE: Python has different behavior for negative mod and floor div than c
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -14,7 +13,7 @@ class Node:
max: int
def render(self, ops=None, ctx=None) -> str:
if ops is None: ops = render_python
assert isinstance(self, (Variable, NumNode)) or self.min != self.max
assert self.__class__ in (Variable, NumNode) or self.min != self.max
return ops[type(self)](self, ops, ctx)
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@@ -57,49 +56,42 @@ class Node:
@staticmethod
def num(num:int) -> Node: return NumNode(num)
@staticmethod
def factorize(nodes:List[Node]):
mul_groups: Dict[Node, int] = {}
for x in nodes:
a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
mul_groups[a] = mul_groups.get(a, 0) + b
return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
@staticmethod
def sum(nodes:List[Node]) -> Node:
nodes = [x for x in nodes if x.max or x.min]
if not nodes: return NumNode(0)
if len(nodes) == 1: return nodes[0]
new_nodes: List[Node] = []
sum_nodes: List[SumNode] = []
num_nodes: List[NumNode] = []
mul_nodes: List[MulNode] = []
num_node_sum = 0
# flatten all sumnodes and gather numnodes
for node in nodes:
if isinstance(node, NumNode):
num_nodes.append(node)
elif isinstance(node, MulNode):
mul_nodes.append(node)
elif isinstance(node, SumNode): # expand any sums inside one sum
sum_nodes.append(node)
else:
new_nodes.append(node)
if node.__class__ not in (NumNode, SumNode): new_nodes.append(node)
elif node.__class__ is NumNode: num_node_sum += node.b
elif isinstance(node, SumNode): # mypy wants the isinstance
for sub_node in node.flat_components:
if sub_node.__class__ is NumNode: num_node_sum += sub_node.b
else: new_nodes.append(sub_node)
# expand any sums inside one sum
if sum_nodes:
new_nodes.extend(num_nodes)
new_nodes.extend(mul_nodes)
for x in sum_nodes: new_nodes += x.nodes
return Variable.sum(new_nodes)
# combine any numbers inside a sum
if num_nodes:
new_nodes.append(NumNode(sum([x.b for x in num_nodes])))
# combine any MulNodes that factorize (big hack sticking the MulNode(x, 1) on things)
mul_nodes += [MulNode(x, 1) for x in new_nodes]
mul_groups: Dict[str, Tuple[Node, List[MulNode]]] = defaultdict(lambda: (Node(), []))
for node in mul_nodes: #NOTE can we somehow avoid rendering here?
key = node.a.render()
mul_groups[key] = (node.a, mul_groups[key][1] + [node])
mul_nodes = [k * sum(x.b for x in g) for k, g in mul_groups.values()]
new_nodes = [x if not isinstance(x, MulNode) or x.b != 1 else x.a for x in mul_nodes]
# filter 0s
new_nodes = [x for x in new_nodes if x.min != 0 or x.max != 0]
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else (new_nodes[0] if len(new_nodes) == 1 else NumNode(0))
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
new_nodes = Node.factorize(new_nodes)
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
@staticmethod
def ands(nodes:List[Node]) -> Node:
if any((x.min == 0 and x.max == 0) for x in nodes): return NumNode(0)
if not nodes: return NumNode(1)
if len(nodes) == 1: return nodes[0]
if any([x.min == x.max == 0 for x in nodes]): return NumNode(0)
# filter 1s
nodes = [x for x in nodes if x.min != x.max]
@@ -164,40 +156,46 @@ class RedNode(Node):
def __init__(self, nodes:List[Node]): self.nodes = nodes
class SumNode(RedNode):
def __mul__(self, b: int): return Variable.sum([x*b for x in self.nodes]) # distribute mul into sum
def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
def __floordiv__(self, b: int, factoring_allowed=True):
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
nofactor = []
# ugh, i doubt this is universally right
for x in tmp_nofactor:
if isinstance(x, NumNode):
if (x.b%b) != x.b:
factors.append(Variable.num(x.b - (x.b%b))) # python does floor division
nofactor.append(Variable.num(x.b%b))
else:
nofactor.append(x)
gcd = [math.gcd(x.b, b) if isinstance(x, (MulNode, NumNode)) else None for x in nofactor]
if len(factors) > 0:
# these don't have to be the same, just having a common factor
if len(gcd) > 0 and all_same(gcd) and gcd[0] is not None and gcd[0] > 1:
nofactor_term = Variable.sum([(x.a * (x.b//gcd[0])) if isinstance(x, MulNode) else Variable.num(x.b//gcd[0]) for x in nofactor])//(b//gcd[0])
else:
nofactor_term = Variable.sum(nofactor)//b
return Variable.sum([(x.a * (x.b//b)) if isinstance(x, MulNode) else Variable.num(x.b//b) for x in factors] + [nofactor_term])
else:
muls = [x.b for x in nofactor if isinstance(x, MulNode)]
for m in muls:
if m > 1 and b%m == 0:
return (self//m)//(b//m)
return Node.__floordiv__(self, b, factoring_allowed)
factors: List[Node] = []
nofactor_mul: List[Node] = []
nofactor_nonmul: List[Node] = []
for x in self.flat_components:
if x.__class__ is NumNode and x.b%b == 0: factors.append(x)
elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x)
else: nofactor_nonmul.append(x)
if factors: # factor out largest possible gcd
factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors]
if nofactor_mul and not nofactor_nonmul:
gcds = [gcd(x.b, b) for x in nofactor_mul]
if (t := min(gcds)) > 1 and all([x.b%t == 0 for x in nofactor_mul]):
nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance
else:
nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else []
else:
nofactor_term = [Node.sum(nofactor_mul+nofactor_nonmul)//b] if nofactor_mul + nofactor_nonmul else []
return Node.sum(factor_term + nofactor_term)
for m in nofactor_mul:
if m.b > 1 and b%m.b == 0: return (self//m.b)//(b//m.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: int):
new_nodes = []
for x in self.nodes:
if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b))
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
else: new_nodes.append(x)
return Node.__mod__(Variable.sum(new_nodes), b)
return Node.__mod__(Node.sum(new_nodes), b)
@property
def flat_components(self): # recursively expand sumnode components
new_nodes = []
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
return new_nodes
class AndNode(RedNode):
def __mul__(self, b: int): Variable.ands([x*b for x in self.nodes])