mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
lt sum (#1617)
This commit is contained in:
@@ -173,6 +173,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "1")
|
||||
|
||||
def test_lt_sum_remove(self):
|
||||
self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)")
|
||||
|
||||
def test_and_fold(self):
|
||||
self.helper_test_variable(Variable.ands([Variable.num(0), Variable("a", 0, 1)]), 0, 0, "0")
|
||||
|
||||
|
||||
@@ -151,7 +151,8 @@ class Variable(Node):
|
||||
|
||||
class NumNode(Node):
|
||||
def __init__(self, num:int):
|
||||
self.b, self.min, self.max = num, num, num
|
||||
self.b:int = num
|
||||
self.min, self.max = num, num
|
||||
def __int__(self): return self.b
|
||||
def __eq__(self, other): return self.b == other
|
||||
def __hash__(self): return self.hash # needed with __eq__ override
|
||||
@@ -253,6 +254,16 @@ class SumNode(RedNode):
|
||||
else: new_nodes.append(x)
|
||||
return Node.__mod__(Node.sum(new_nodes), b)
|
||||
|
||||
def __lt__(self, b:Union[Node,int]):
|
||||
if isinstance(b, int):
|
||||
new_sum = []
|
||||
for x in self.nodes:
|
||||
# TODO: should we just force the last one to always be the number
|
||||
if isinstance(x, NumNode): b -= x.b
|
||||
else: new_sum.append(x)
|
||||
return Node.__lt__(Node.sum(new_sum), b)
|
||||
return Node.__lt__(self, b)
|
||||
|
||||
@property
|
||||
def flat_components(self): # recursively expand sumnode components
|
||||
new_nodes = []
|
||||
|
||||
Reference in New Issue
Block a user