mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
simpler ModNode.__mod__ and ModNode.__floordiv__ (#2983)
`gcd(self.b, b) == b` is equivalent to `self.b % b == 0`. use the same condition and format in __floordiv__ too.
This commit is contained in:
@@ -168,8 +168,7 @@ class OpNode(Node):
|
||||
class LtNode(OpNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
if isinstance(self.b, int):
|
||||
return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
|
||||
if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
|
||||
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
|
||||
return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
||||
@@ -199,10 +198,9 @@ class DivNode(OpNode):
|
||||
class ModNode(OpNode):
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b)
|
||||
return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b)
|
||||
return self.a % b if self.b % b == 0 else Node.__mod__(self, b)
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
||||
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0 and isinstance(self.b, int)
|
||||
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) # noqa: E501
|
||||
@@ -287,11 +285,10 @@ class SumNode(RedNode):
|
||||
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Node.sum([node.substitute(var_vals) for node in self.nodes])
|
||||
|
||||
# recursively expand sumnode components
|
||||
# TODO: can remove this if there's no SumNode inside SumNode
|
||||
@property
|
||||
def flat_components(self): # recursively expand sumnode components
|
||||
new_nodes = []
|
||||
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
|
||||
return new_nodes
|
||||
def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
|
||||
|
||||
class AndNode(RedNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=True): return Node.ands([x//b for x in self.nodes])
|
||||
|
||||
Reference in New Issue
Block a user