mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
even simpler symbolic
This commit is contained in:
@@ -93,57 +93,34 @@ class NumNode(Node):
|
||||
self.b, self.min, self.max = num, num, num
|
||||
|
||||
class OpNode(Node):
|
||||
def __init__(self, a:Node, b:int):
|
||||
self.a, self.b = a,b
|
||||
self.min, self.max = self.minmax(a,b)
|
||||
@staticmethod
|
||||
def minmax(a, b): raise NotImplementedError()
|
||||
@property
|
||||
def expr(self):
|
||||
return f"({self.a}{self.op}{self.b})"
|
||||
|
||||
class RedNode(Node):
|
||||
def __init__(self, nodes:List[Node]):
|
||||
self.nodes = nodes
|
||||
self.min, self.max = self.minmax(nodes)
|
||||
@staticmethod
|
||||
def minmax(nodes): raise NotImplementedError()
|
||||
@property
|
||||
def expr(self):
|
||||
return f"({self.op.join([str(x) for x in self.nodes])})"
|
||||
|
||||
# operation nodes
|
||||
|
||||
class MulNode(OpNode):
|
||||
op, minf, maxf = "*", lambda a,b: a.min*b, lambda a,b: a.max*b
|
||||
def __init__(self, a:Node, b:int):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = a.min*b, a.max*b
|
||||
|
||||
class DivNode(OpNode):
|
||||
op = "//"
|
||||
def __init__(self, a:Node, b:int):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = int(a.min/b), int(a.max/b)
|
||||
|
||||
class ModNode(OpNode):
|
||||
op = "%"
|
||||
def __init__(self, a:Node, b:int):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = min(a.min, 0), max(a.max, b-1)
|
||||
|
||||
class GeNode(OpNode):
|
||||
op = ">="
|
||||
def __init__(self, a:Node, b:int):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = 0, 1
|
||||
|
||||
class LtNode(OpNode):
|
||||
op = "<"
|
||||
def __init__(self, a:Node, b:int):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = 0, 1
|
||||
class MulNode(OpNode): op, minmax = "*", staticmethod(lambda a,b: (a.min*b, a.max*b))
|
||||
class DivNode(OpNode): op, minmax = "//", staticmethod(lambda a,b: (int(a.min/b), int(a.max/b)))
|
||||
class ModNode(OpNode): op, minmax = "%", staticmethod(lambda a,b: (min(a.min, 0), max(a.max, b-1)))
|
||||
class GeNode(OpNode): op, minmax = ">=", staticmethod(lambda a,b: (0,1))
|
||||
class LtNode(OpNode): op, minmax = "<", staticmethod(lambda a,b: (0,1))
|
||||
|
||||
# reduce nodes
|
||||
|
||||
class SumNode(RedNode):
|
||||
op = "+"
|
||||
def __init__(self, nodes:List[Node]):
|
||||
self.nodes = nodes
|
||||
self.min, self.max = sum([x.min for x in nodes]), sum([x.max for x in nodes])
|
||||
|
||||
class AndNode(RedNode):
|
||||
op = "&&"
|
||||
def __init__(self, nodes:List[Node]):
|
||||
self.nodes = nodes
|
||||
self.min, self.max = min([x.min for x in nodes]), max([x.max for x in nodes])
|
||||
class SumNode(RedNode): op, minmax = "+", staticmethod(lambda nodes: (sum([x.min for x in nodes]), sum([x.max for x in nodes])))
|
||||
class AndNode(RedNode): op, minmax = "&&", staticmethod(lambda nodes: (min([x.min for x in nodes]), max([x.max for x in nodes])))
|
||||
|
||||
Reference in New Issue
Block a user