simpler SumNode.__mod__ (#2979)

* simpler SumNode.__mod__

delegate simplification to individual node

* ModNode.__mod__ simplification case

* Revert "ModNode.__mod__ simplification case"

This reverts commit 73a42205a8.
This commit is contained in:
chenyu
2024-01-02 15:09:15 -05:00
committed by GitHub
parent 91ddda244f
commit 878e869663

View File

@@ -184,9 +184,7 @@ class MulNode(OpNode):
if self.b % b == 0: return self.a*(self.b//b)
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: Union[Node, int]):
a = (self.a * (self.b%b))
return Node.__mod__(a, b)
def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
@@ -262,8 +260,7 @@ class SumNode(RedNode):
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
new_nodes: List[Node] = []
for x in self.nodes:
if x.__class__ is NumNode: new_nodes.append(NumNode(x.b%b))
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
if x.__class__ in (NumNode, MulNode): new_nodes.append(x%b) # might simplify
else: new_nodes.append(x)
return Node.__mod__(Node.sum(new_nodes), b)