mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Simplify symbolic.SumNode.__floordiv__ logic (#1220)
This commit is contained in:
@@ -119,6 +119,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sum_div_some_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
|
||||
|
||||
def test_sum_div_some_partial_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||
|
||||
def test_sum_div_no_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
|
||||
|
||||
@@ -186,28 +186,23 @@ class SumNode(RedNode):
|
||||
def __floordiv__(self, b: int, factoring_allowed=True):
|
||||
if b == 1: return self
|
||||
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
||||
factors: List[Node] = []
|
||||
nofactor_mul: List[Node] = []
|
||||
nofactor_nonmul: List[Node] = []
|
||||
fully_divided: List[Node] = []
|
||||
rest: List[Node] = []
|
||||
_gcd = b
|
||||
divisor = 1
|
||||
for x in self.flat_components:
|
||||
if x.__class__ is NumNode and x.b%b == 0: factors.append(x)
|
||||
elif x.__class__ is MulNode: factors.append(x) if x.b%b == 0 else nofactor_mul.append(x)
|
||||
else: nofactor_nonmul.append(x)
|
||||
|
||||
if factors: # factor out largest possible gcd
|
||||
factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors]
|
||||
if nofactor_mul and not nofactor_nonmul:
|
||||
gcds = [gcd(x.b, b) for x in nofactor_mul]
|
||||
if (t := min(gcds)) > 1 and all(x.b%t == 0 for x in nofactor_mul):
|
||||
nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance
|
||||
if x.__class__ in (NumNode, MulNode):
|
||||
if x.b%b == 0: fully_divided.append(x//b)
|
||||
else:
|
||||
nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else []
|
||||
rest.append(x)
|
||||
_gcd = gcd(_gcd, x.b)
|
||||
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
|
||||
else:
|
||||
nofactor_term = [Node.sum(nofactor_mul+nofactor_nonmul)//b] if nofactor_mul + nofactor_nonmul else []
|
||||
return Node.sum(factor_term + nofactor_term)
|
||||
for m in nofactor_mul:
|
||||
if m.b > 1 and b%m.b == 0: return (self//m.b)//(b//m.b)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
rest.append(x)
|
||||
_gcd = 1
|
||||
if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd)
|
||||
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
|
||||
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
|
||||
|
||||
def __mod__(self, b: int):
|
||||
new_nodes: List[Node] = []
|
||||
|
||||
Reference in New Issue
Block a user