mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
sum_combine_num
This commit is contained in:
@@ -73,8 +73,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
def test_mod_factor(self):
|
||||
# this is technically wrong, if b is 0 the output will be negative
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -1, 9, "((-1+a)%28)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((-1+a)%28)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -1, 9, "((a+-1)%28)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((a+-1)%28)")
|
||||
|
||||
def test_sum_combine_num(self):
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable.num(23)]), -6, 4, "(a+-6)")
|
||||
|
||||
def test_div_factor(self):
|
||||
# TODO: this isn't right
|
||||
|
||||
@@ -65,6 +65,10 @@ class Node:
|
||||
|
||||
@staticmethod
|
||||
def sum(nodes:List[Node]) -> Node:
|
||||
nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode))
|
||||
num_sum = sum([x.b for x in num_nodes])
|
||||
if num_sum != 0: nodes.append(NumNode(num_sum))
|
||||
|
||||
if any([isinstance(x, SumNode) for x in nodes]):
|
||||
nodes, sum_nodes = partition(nodes, lambda x: not isinstance(x, SumNode))
|
||||
for x in sum_nodes: nodes += x.nodes
|
||||
|
||||
Reference in New Issue
Block a user