mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
simpler SumNode.__mod__ (#2979)
* simpler SumNode.__mod__
delegate simplification to individual node
* ModNode.__mod__ simplification case
* Revert "ModNode.__mod__ simplification case"
This reverts commit 73a42205a8.
This commit is contained in:
@@ -184,9 +184,7 @@ class MulNode(OpNode):
|
||||
if self.b % b == 0: return self.a*(self.b//b)
|
||||
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
a = (self.a * (self.b%b))
|
||||
return Node.__mod__(a, b)
|
||||
def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
|
||||
def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
|
||||
return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
||||
@@ -262,8 +260,7 @@ class SumNode(RedNode):
|
||||
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
|
||||
new_nodes: List[Node] = []
|
||||
for x in self.nodes:
|
||||
if x.__class__ is NumNode: new_nodes.append(NumNode(x.b%b))
|
||||
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
|
||||
if x.__class__ in (NumNode, MulNode): new_nodes.append(x%b) # might simplify
|
||||
else: new_nodes.append(x)
|
||||
return Node.__mod__(Node.sum(new_nodes), b)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user