diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 9cea7cb586..f67806d2af 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -168,8 +168,7 @@ class OpNode(Node): class LtNode(OpNode): def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b) def get_bounds(self) -> Tuple[int, int]: - if isinstance(self.b, int): - return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1) + if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1) return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1) def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) @@ -199,10 +198,9 @@ class DivNode(OpNode): class ModNode(OpNode): def __mod__(self, b: Union[Node, int]): if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b) - return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b) + return self.a % b if self.b % b == 0 else Node.__mod__(self, b) def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): - if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod - return Node.__floordiv__(self, b, factoring_allowed) + return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed) def get_bounds(self) -> Tuple[int, int]: assert self.a.min >= 0 and isinstance(self.b, int) return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) # noqa: E501 @@ -287,11 +285,10 @@ class SumNode(RedNode): def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Node.sum([node.substitute(var_vals) for node in self.nodes]) + # recursively expand sumnode components + # TODO: can remove this if there's no SumNode inside SumNode @property - def flat_components(self): # recursively expand sumnode components - new_nodes = [] - for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x]) - return new_nodes + def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])] class AndNode(RedNode): def __floordiv__(self, b: Union[Node, int], _=True): return Node.ands([x//b for x in self.nodes])