simpler ModNode.__mod__ and ModNode.__floordiv__ (#2983)

`gcd(self.b, b) == b` is equivalent to `self.b % b == 0`.
use the same condition and format in __floordiv__ too.
This commit is contained in:
chenyu
2024-01-02 18:52:42 -05:00
committed by GitHub
parent c07907e644
commit 0dd3ca59cd

View File

@@ -168,8 +168,7 @@ class OpNode(Node):
class LtNode(OpNode):
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
def get_bounds(self) -> Tuple[int, int]:
if isinstance(self.b, int):
return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
@@ -199,10 +198,9 @@ class DivNode(OpNode):
class ModNode(OpNode):
def __mod__(self, b: Union[Node, int]):
if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b)
return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b)
return self.a % b if self.b % b == 0 else Node.__mod__(self, b)
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
return Node.__floordiv__(self, b, factoring_allowed)
return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0 and isinstance(self.b, int)
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) # noqa: E501
@@ -287,11 +285,10 @@ class SumNode(RedNode):
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Node.sum([node.substitute(var_vals) for node in self.nodes])
# recursively expand sumnode components
# TODO: can remove this if there's no SumNode inside SumNode
@property
def flat_components(self): # recursively expand sumnode components
new_nodes = []
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
return new_nodes
def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
class AndNode(RedNode):
def __floordiv__(self, b: Union[Node, int], _=True): return Node.ands([x//b for x in self.nodes])