even simpler symbolic

This commit is contained in:
George Hotz
2023-02-06 22:47:00 -06:00
parent 8b05de1841
commit 7c5a5ecdac

View File

@@ -93,57 +93,34 @@ class NumNode(Node):
self.b, self.min, self.max = num, num, num
class OpNode(Node):
def __init__(self, a:Node, b:int):
self.a, self.b = a,b
self.min, self.max = self.minmax(a,b)
@staticmethod
def minmax(a, b): raise NotImplementedError()
@property
def expr(self):
return f"({self.a}{self.op}{self.b})"
class RedNode(Node):
def __init__(self, nodes:List[Node]):
self.nodes = nodes
self.min, self.max = self.minmax(nodes)
@staticmethod
def minmax(nodes): raise NotImplementedError()
@property
def expr(self):
return f"({self.op.join([str(x) for x in self.nodes])})"
# operation nodes
class MulNode(OpNode):
op, minf, maxf = "*", lambda a,b: a.min*b, lambda a,b: a.max*b
def __init__(self, a:Node, b:int):
self.a, self.b = a, b
self.min, self.max = a.min*b, a.max*b
class DivNode(OpNode):
op = "//"
def __init__(self, a:Node, b:int):
self.a, self.b = a, b
self.min, self.max = int(a.min/b), int(a.max/b)
class ModNode(OpNode):
op = "%"
def __init__(self, a:Node, b:int):
self.a, self.b = a, b
self.min, self.max = min(a.min, 0), max(a.max, b-1)
class GeNode(OpNode):
op = ">="
def __init__(self, a:Node, b:int):
self.a, self.b = a, b
self.min, self.max = 0, 1
class LtNode(OpNode):
op = "<"
def __init__(self, a:Node, b:int):
self.a, self.b = a, b
self.min, self.max = 0, 1
class MulNode(OpNode): op, minmax = "*", staticmethod(lambda a,b: (a.min*b, a.max*b))
class DivNode(OpNode): op, minmax = "//", staticmethod(lambda a,b: (int(a.min/b), int(a.max/b)))
class ModNode(OpNode): op, minmax = "%", staticmethod(lambda a,b: (min(a.min, 0), max(a.max, b-1)))
class GeNode(OpNode): op, minmax = ">=", staticmethod(lambda a,b: (0,1))
class LtNode(OpNode): op, minmax = "<", staticmethod(lambda a,b: (0,1))
# reduce nodes
class SumNode(RedNode):
op = "+"
def __init__(self, nodes:List[Node]):
self.nodes = nodes
self.min, self.max = sum([x.min for x in nodes]), sum([x.max for x in nodes])
class AndNode(RedNode):
op = "&&"
def __init__(self, nodes:List[Node]):
self.nodes = nodes
self.min, self.max = min([x.min for x in nodes]), max([x.max for x in nodes])
class SumNode(RedNode): op, minmax = "+", staticmethod(lambda nodes: (sum([x.min for x in nodes]), sum([x.max for x in nodes])))
class AndNode(RedNode): op, minmax = "&&", staticmethod(lambda nodes: (min([x.min for x in nodes]), max([x.max for x in nodes])))