mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
remove GeNode (#965)
This commit is contained in:
@@ -6,14 +6,13 @@ from tinygrad.helpers import dtypes
|
||||
from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
||||
def int_const(x): return ir.Constant(ir.IntType(64), x)
|
||||
render_llvm = {
|
||||
NumNode: lambda self,ops,ctx: int_const(self.b),
|
||||
MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)),
|
||||
DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)),
|
||||
ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)),
|
||||
GeNode: lambda self,ops,ctx: ctx.icmp_signed(">=", self.a.render(ops,ctx), int_const(self.b)),
|
||||
LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)),
|
||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
|
||||
|
||||
@@ -31,6 +31,7 @@ class RawBufferCopyIn(RawBuffer):
|
||||
|
||||
class RawBufferMapped(RawBufferCopyIn):
|
||||
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
|
||||
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
|
||||
def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
|
||||
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class Node:
|
||||
def __neg__(self): return self*-1
|
||||
def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
|
||||
def __sub__(self, b:Union[Node, int]): return self+-b
|
||||
def __ge__(self, b:int): return create_node(GeNode(self, b))
|
||||
def __ge__(self, b:int): return create_node(LtNode(-self, -b+1))
|
||||
def __lt__(self, b:int): return create_node(LtNode(self, b))
|
||||
def __mul__(self, b:int):
|
||||
if b == 0: return NumNode(0)
|
||||
@@ -125,16 +125,12 @@ def create_node(ret:Node):
|
||||
return ret
|
||||
|
||||
class OpNode(Node):
|
||||
def __init__(self, a:Node, b:int):
|
||||
def __init__(self, a:Node, b:int):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = self.get_bounds()
|
||||
@abstractmethod
|
||||
@abstractmethod
|
||||
def get_bounds(self) -> Tuple[int, int]: pass
|
||||
|
||||
class GeNode(OpNode):
|
||||
def __mul__(self, b: int): return (self.a*b) >= (self.b*b)
|
||||
def __floordiv__(self, b: int, _=False): return (self.a//b) >= (self.b//b)
|
||||
def get_bounds(self) -> Tuple[int, int]: return int(self.a.min >= self.b), int(self.a.max >= self.b)
|
||||
class LtNode(OpNode):
|
||||
def __mul__(self, b: int): return (self.a*b) < (self.b*b)
|
||||
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
|
||||
@@ -148,18 +144,18 @@ class MulNode(OpNode):
|
||||
def __mod__(self, b: int):
|
||||
a = (self.a * (self.b%b))
|
||||
return Node.__mod__(a, b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
||||
class DivNode(OpNode):
|
||||
def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0
|
||||
return self.a.min//self.b, self.a.max//self.b
|
||||
class ModNode(OpNode):
|
||||
def __floordiv__(self, b: int, factoring_allowed=True):
|
||||
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0
|
||||
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
|
||||
|
||||
@@ -194,7 +190,7 @@ class SumNode(RedNode):
|
||||
if m > 1 and b%m == 0:
|
||||
return (self//m)//(b//m)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def __mod__(self, b: int):
|
||||
def __mod__(self, b: int):
|
||||
new_nodes = []
|
||||
for x in self.nodes:
|
||||
if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b))
|
||||
@@ -202,7 +198,7 @@ class SumNode(RedNode):
|
||||
else: new_nodes.append(x)
|
||||
return Node.__mod__(Variable.sum(new_nodes), b)
|
||||
|
||||
class AndNode(RedNode):
|
||||
class AndNode(RedNode):
|
||||
def __mul__(self, b: int): Variable.ands([x*b for x in self.nodes])
|
||||
def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes])
|
||||
|
||||
@@ -218,7 +214,6 @@ render_python: Dict[Type, Callable] = {
|
||||
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})",
|
||||
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
||||
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
|
||||
GeNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}>={self.b})",
|
||||
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{self.b})",
|
||||
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
||||
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
|
||||
|
||||
Reference in New Issue
Block a user