This commit is contained in:
George Hotz
2023-08-21 21:19:16 -07:00
committed by GitHub
parent c64c47a6ae
commit 86a32ffb1a
2 changed files with 15 additions and 1 deletions

View File

@@ -173,6 +173,9 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "1")
def test_lt_sum_remove(self):
self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)")
def test_and_fold(self):
self.helper_test_variable(Variable.ands([Variable.num(0), Variable("a", 0, 1)]), 0, 0, "0")

View File

@@ -151,7 +151,8 @@ class Variable(Node):
class NumNode(Node):
def __init__(self, num:int):
self.b, self.min, self.max = num, num, num
self.b:int = num
self.min, self.max = num, num
def __int__(self): return self.b
def __eq__(self, other): return self.b == other
def __hash__(self): return self.hash # needed with __eq__ override
@@ -253,6 +254,16 @@ class SumNode(RedNode):
else: new_nodes.append(x)
return Node.__mod__(Node.sum(new_nodes), b)
def __lt__(self, b:Union[Node,int]):
if isinstance(b, int):
new_sum = []
for x in self.nodes:
# TODO: should we just force the last one to always be the number
if isinstance(x, NumNode): b -= x.b
else: new_sum.append(x)
return Node.__lt__(Node.sum(new_sum), b)
return Node.__lt__(self, b)
@property
def flat_components(self): # recursively expand sumnode components
new_nodes = []