mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
symbolic.py: faster Node.sum, faster SumNode.div (#1014)
* refactor: replace isinstance with class check where possible
* refactor: faster partition
* fix: flake8
* feat: rework node.sum, correct list typing
* fix: typo
* feat: refactor sum
* fix: pylint
* refactor: simpler sum and factorize
* feat: clean up sumnode div, all cpu tests pass
* feat: simplify floordiv, cache factorization
* don't factor numnodes at all
* python 3.8 functools does not yet have @cache
* fix: restore assert
* refactor, fix failing tests
* fix: address review comments
* feat: rework, add specialization, remove cache
* fix: remove specialization
* feat: no tuple conversion, faster loop

---------

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
import math, functools
|
||||
import functools
|
||||
from math import gcd
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union
|
||||
from tinygrad.helpers import partition, all_same
|
||||
|
||||
# NOTE: Python has different behavior for negative mod and floor div than c
|
||||
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
||||
@@ -14,7 +13,7 @@ class Node:
|
||||
max: int
|
||||
def render(self, ops=None, ctx=None) -> str:
|
||||
if ops is None: ops = render_python
|
||||
assert isinstance(self, (Variable, NumNode)) or self.min != self.max
|
||||
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
@functools.cached_property
|
||||
def key(self) -> str: return self.render(ctx="DEBUG")
|
||||
@@ -57,49 +56,42 @@ class Node:
|
||||
@staticmethod
|
||||
def num(num:int) -> Node: return NumNode(num)
|
||||
|
||||
@staticmethod
|
||||
def factorize(nodes:List[Node]):
|
||||
mul_groups: Dict[Node, int] = {}
|
||||
for x in nodes:
|
||||
a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
|
||||
mul_groups[a] = mul_groups.get(a, 0) + b
|
||||
return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
|
||||
|
||||
@staticmethod
|
||||
def sum(nodes:List[Node]) -> Node:
|
||||
nodes = [x for x in nodes if x.max or x.min]
|
||||
if not nodes: return NumNode(0)
|
||||
if len(nodes) == 1: return nodes[0]
|
||||
|
||||
new_nodes: List[Node] = []
|
||||
sum_nodes: List[SumNode] = []
|
||||
num_nodes: List[NumNode] = []
|
||||
mul_nodes: List[MulNode] = []
|
||||
num_node_sum = 0
|
||||
|
||||
# flatten all sumnodes and gather numnodes
|
||||
for node in nodes:
|
||||
if isinstance(node, NumNode):
|
||||
num_nodes.append(node)
|
||||
elif isinstance(node, MulNode):
|
||||
mul_nodes.append(node)
|
||||
elif isinstance(node, SumNode): # expand any sums inside one sum
|
||||
sum_nodes.append(node)
|
||||
else:
|
||||
new_nodes.append(node)
|
||||
if node.__class__ not in (NumNode, SumNode): new_nodes.append(node)
|
||||
elif node.__class__ is NumNode: num_node_sum += node.b
|
||||
elif isinstance(node, SumNode): # mypy wants the isinstance
|
||||
for sub_node in node.flat_components:
|
||||
if sub_node.__class__ is NumNode: num_node_sum += sub_node.b
|
||||
else: new_nodes.append(sub_node)
|
||||
|
||||
# expand any sums inside one sum
|
||||
if sum_nodes:
|
||||
new_nodes.extend(num_nodes)
|
||||
new_nodes.extend(mul_nodes)
|
||||
for x in sum_nodes: new_nodes += x.nodes
|
||||
return Variable.sum(new_nodes)
|
||||
|
||||
# combine any numbers inside a sum
|
||||
if num_nodes:
|
||||
new_nodes.append(NumNode(sum([x.b for x in num_nodes])))
|
||||
|
||||
# combine any MulNodes that factorize (big hack sticking the MulNode(x, 1) on things)
|
||||
mul_nodes += [MulNode(x, 1) for x in new_nodes]
|
||||
mul_groups: Dict[str, Tuple[Node, List[MulNode]]] = defaultdict(lambda: (Node(), []))
|
||||
for node in mul_nodes: #NOTE can we somehow avoid rendering here?
|
||||
key = node.a.render()
|
||||
mul_groups[key] = (node.a, mul_groups[key][1] + [node])
|
||||
mul_nodes = [k * sum(x.b for x in g) for k, g in mul_groups.values()]
|
||||
new_nodes = [x if not isinstance(x, MulNode) or x.b != 1 else x.a for x in mul_nodes]
|
||||
|
||||
# filter 0s
|
||||
new_nodes = [x for x in new_nodes if x.min != 0 or x.max != 0]
|
||||
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else (new_nodes[0] if len(new_nodes) == 1 else NumNode(0))
|
||||
if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
|
||||
new_nodes = Node.factorize(new_nodes)
|
||||
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
|
||||
return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
|
||||
|
||||
@staticmethod
|
||||
def ands(nodes:List[Node]) -> Node:
|
||||
if any((x.min == 0 and x.max == 0) for x in nodes): return NumNode(0)
|
||||
if not nodes: return NumNode(1)
|
||||
if len(nodes) == 1: return nodes[0]
|
||||
if any([x.min == x.max == 0 for x in nodes]): return NumNode(0)
|
||||
|
||||
# filter 1s
|
||||
nodes = [x for x in nodes if x.min != x.max]
|
||||
@@ -164,40 +156,46 @@ class RedNode(Node):
|
||||
def __init__(self, nodes:List[Node]): self.nodes = nodes
|
||||
|
||||
class SumNode(RedNode):
|
||||
def __mul__(self, b: int): return Variable.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
def __floordiv__(self, b: int, factoring_allowed=True):
|
||||
if b == 1: return self
|
||||
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
||||
factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
|
||||
nofactor = []
|
||||
# ugh, i doubt this is universally right
|
||||
for x in tmp_nofactor:
|
||||
if isinstance(x, NumNode):
|
||||
if (x.b%b) != x.b:
|
||||
factors.append(Variable.num(x.b - (x.b%b))) # python does floor division
|
||||
nofactor.append(Variable.num(x.b%b))
|
||||
else:
|
||||
nofactor.append(x)
|
||||
gcd = [math.gcd(x.b, b) if isinstance(x, (MulNode, NumNode)) else None for x in nofactor]
|
||||
if len(factors) > 0:
|
||||
# these don't have to be the same, just having a common factor
|
||||
if len(gcd) > 0 and all_same(gcd) and gcd[0] is not None and gcd[0] > 1:
|
||||
nofactor_term = Variable.sum([(x.a * (x.b//gcd[0])) if isinstance(x, MulNode) else Variable.num(x.b//gcd[0]) for x in nofactor])//(b//gcd[0])
|
||||
else:
|
||||
nofactor_term = Variable.sum(nofactor)//b
|
||||
return Variable.sum([(x.a * (x.b//b)) if isinstance(x, MulNode) else Variable.num(x.b//b) for x in factors] + [nofactor_term])
|
||||
else:
|
||||
muls = [x.b for x in nofactor if isinstance(x, MulNode)]
|
||||
for m in muls:
|
||||
if m > 1 and b%m == 0:
|
||||
return (self//m)//(b//m)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
factors: List[Node] = []
|
||||
nofactor_mul: List[Node] = []
|
||||
nofactor_nonmul: List[Node] = []
|
||||
for x in self.flat_components:
|
||||
if x.__class__ is NumNode and x.b%b == 0: factors.append(x)
|
||||
elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x)
|
||||
else: nofactor_nonmul.append(x)
|
||||
|
||||
if factors: # factor out largest possible gcd
|
||||
factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors]
|
||||
if nofactor_mul and not nofactor_nonmul:
|
||||
gcds = [gcd(x.b, b) for x in nofactor_mul]
|
||||
if (t := min(gcds)) > 1 and all([x.b%t == 0 for x in nofactor_mul]):
|
||||
nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance
|
||||
else:
|
||||
nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else []
|
||||
else:
|
||||
nofactor_term = [Node.sum(nofactor_mul+nofactor_nonmul)//b] if nofactor_mul + nofactor_nonmul else []
|
||||
return Node.sum(factor_term + nofactor_term)
|
||||
for m in nofactor_mul:
|
||||
if m.b > 1 and b%m.b == 0: return (self//m.b)//(b//m.b)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
|
||||
def __mod__(self, b: int):
|
||||
new_nodes = []
|
||||
for x in self.nodes:
|
||||
if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b))
|
||||
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
|
||||
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
|
||||
else: new_nodes.append(x)
|
||||
return Node.__mod__(Variable.sum(new_nodes), b)
|
||||
return Node.__mod__(Node.sum(new_nodes), b)
|
||||
|
||||
@property
|
||||
def flat_components(self): # recursively expand sumnode components
|
||||
new_nodes = []
|
||||
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
|
||||
return new_nodes
|
||||
|
||||
class AndNode(RedNode):
|
||||
# Multiply every operand of this node by `b` and recombine them with Variable.ands.
# The combined node must be returned: without `return`, `and_node * b` evaluated to None.
def __mul__(self, b: int): return Variable.ands([x*b for x in self.nodes])
|
||||
|
||||
Reference in New Issue
Block a user