remove GeNode (#965)

This commit is contained in:
George Hotz
2023-06-09 21:48:56 -07:00
committed by GitHub
parent 2c324d0685
commit c62c64f0b7
4 changed files with 12 additions and 17 deletions

View File

@@ -6,14 +6,13 @@ from tinygrad.helpers import dtypes
from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
def int_const(x): return ir.Constant(ir.IntType(64), x)
render_llvm = {
NumNode: lambda self,ops,ctx: int_const(self.b),
MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)),
DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)),
ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)),
GeNode: lambda self,ops,ctx: ctx.icmp_signed(">=", self.a.render(ops,ctx), int_const(self.b)),
LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))

View File

@@ -31,6 +31,7 @@ class RawBufferCopyIn(RawBuffer):
class RawBufferMapped(RawBufferCopyIn):
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))

View File

@@ -25,7 +25,7 @@ class Node:
def __neg__(self): return self*-1
def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
def __sub__(self, b:Union[Node, int]): return self+-b
def __ge__(self, b:int): return create_node(GeNode(self, b))
def __ge__(self, b:int): return create_node(LtNode(-self, -b+1))
def __lt__(self, b:int): return create_node(LtNode(self, b))
def __mul__(self, b:int):
if b == 0: return NumNode(0)
@@ -125,16 +125,12 @@ def create_node(ret:Node):
return ret
class OpNode(Node):
def __init__(self, a:Node, b:int):
def __init__(self, a:Node, b:int):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
@abstractmethod
@abstractmethod
def get_bounds(self) -> Tuple[int, int]: pass
class GeNode(OpNode):
def __mul__(self, b: int): return (self.a*b) >= (self.b*b)
def __floordiv__(self, b: int, _=False): return (self.a//b) >= (self.b//b)
def get_bounds(self) -> Tuple[int, int]: return int(self.a.min >= self.b), int(self.a.max >= self.b)
class LtNode(OpNode):
def __mul__(self, b: int): return (self.a*b) < (self.b*b)
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
@@ -148,18 +144,18 @@ class MulNode(OpNode):
def __mod__(self, b: int):
a = (self.a * (self.b%b))
return Node.__mod__(a, b)
def get_bounds(self) -> Tuple[int, int]:
def get_bounds(self) -> Tuple[int, int]:
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
class DivNode(OpNode):
def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
def get_bounds(self) -> Tuple[int, int]:
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0
return self.a.min//self.b, self.a.max//self.b
class ModNode(OpNode):
def __floordiv__(self, b: int, factoring_allowed=True):
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
return Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, int]:
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
@@ -194,7 +190,7 @@ class SumNode(RedNode):
if m > 1 and b%m == 0:
return (self//m)//(b//m)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: int):
def __mod__(self, b: int):
new_nodes = []
for x in self.nodes:
if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b))
@@ -202,7 +198,7 @@ class SumNode(RedNode):
else: new_nodes.append(x)
return Node.__mod__(Variable.sum(new_nodes), b)
class AndNode(RedNode):
class AndNode(RedNode):
def __mul__(self, b: int): Variable.ands([x*b for x in self.nodes])
def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes])
@@ -218,7 +214,6 @@ render_python: Dict[Type, Callable] = {
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
GeNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}>={self.b})",
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{self.b})",
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"