diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 6ee8b85957..b0013f5bbf 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -93,57 +93,34 @@ class NumNode(Node): self.b, self.min, self.max = num, num, num class OpNode(Node): + def __init__(self, a:Node, b:int): + self.a, self.b = a,b + self.min, self.max = self.minmax(a,b) + @staticmethod + def minmax(a, b): raise NotImplementedError() @property def expr(self): return f"({self.a}{self.op}{self.b})" class RedNode(Node): + def __init__(self, nodes:List[Node]): + self.nodes = nodes + self.min, self.max = self.minmax(nodes) + @staticmethod + def minmax(nodes): raise NotImplementedError() @property def expr(self): return f"({self.op.join([str(x) for x in self.nodes])})" # operation nodes -class MulNode(OpNode): - op, minf, maxf = "*", lambda a,b: a.min*b, lambda a,b: a.max*b - def __init__(self, a:Node, b:int): - self.a, self.b = a, b - self.min, self.max = a.min*b, a.max*b - -class DivNode(OpNode): - op = "//" - def __init__(self, a:Node, b:int): - self.a, self.b = a, b - self.min, self.max = int(a.min/b), int(a.max/b) - -class ModNode(OpNode): - op = "%" - def __init__(self, a:Node, b:int): - self.a, self.b = a, b - self.min, self.max = min(a.min, 0), max(a.max, b-1) - -class GeNode(OpNode): - op = ">=" - def __init__(self, a:Node, b:int): - self.a, self.b = a, b - self.min, self.max = 0, 1 - -class LtNode(OpNode): - op = "<" - def __init__(self, a:Node, b:int): - self.a, self.b = a, b - self.min, self.max = 0, 1 +class MulNode(OpNode): op, minmax = "*", staticmethod(lambda a,b: (a.min*b, a.max*b)) +class DivNode(OpNode): op, minmax = "//", staticmethod(lambda a,b: (int(a.min/b), int(a.max/b))) +class ModNode(OpNode): op, minmax = "%", staticmethod(lambda a,b: (min(a.min, 0), max(a.max, b-1))) +class GeNode(OpNode): op, minmax = ">=", staticmethod(lambda a,b: (0,1)) +class LtNode(OpNode): op, minmax = "<", staticmethod(lambda a,b: (0,1)) # reduce nodes -class SumNode(RedNode): - op = "+" - def __init__(self, nodes:List[Node]): - self.nodes = nodes - self.min, self.max = sum([x.min for x in nodes]), sum([x.max for x in nodes]) - -class AndNode(RedNode): - op = "&&" - def __init__(self, nodes:List[Node]): - self.nodes = nodes - self.min, self.max = min([x.min for x in nodes]), max([x.max for x in nodes]) +class SumNode(RedNode): op, minmax = "+", staticmethod(lambda nodes: (sum([x.min for x in nodes]), sum([x.max for x in nodes]))) +class AndNode(RedNode): op, minmax = "&&", staticmethod(lambda nodes: (min([x.min for x in nodes]), max([x.max for x in nodes])))