From c62c64f0b71920a4d62f64ea94a04f9ada5633b5 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 9 Jun 2023 21:48:56 -0700 Subject: [PATCH] remove GeNode (#965) --- test/unit/test_symbolic.py | 4 ++-- tinygrad/codegen/llvmir.py | 3 +-- tinygrad/runtime/lib.py | 1 + tinygrad/shape/symbolic.py | 21 ++++++++------------- 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 1e0def9bff..eed3fc2c0c 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -11,8 +11,8 @@ class TestSymbolic(unittest.TestCase): def test_ge(self): self.helper_test_variable(Variable("a", 3, 8)>=77, 0, 0, "0") self.helper_test_variable(Variable("a", 3, 8)>=9, 0, 0, "0") - self.helper_test_variable(Variable("a", 3, 8)>=8, 0, 1, "(a>=8)") - self.helper_test_variable(Variable("a", 3, 8)>=4, 0, 1, "(a>=4)") + self.helper_test_variable(Variable("a", 3, 8)>=8, 0, 1, "((a*-1)<-7)") + self.helper_test_variable(Variable("a", 3, 8)>=4, 0, 1, "((a*-1)<-3)") self.helper_test_variable(Variable("a", 3, 8)>=3, 1, 1, "1") self.helper_test_variable(Variable("a", 3, 8)>=2, 1, 1, "1") diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py index 20a200e221..4fc21b545e 100644 --- a/tinygrad/codegen/llvmir.py +++ b/tinygrad/codegen/llvmir.py @@ -6,14 +6,13 @@ from tinygrad.helpers import dtypes from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, FusedOps from tinygrad.lazy import LazyBuffer -from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode +from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode def int_const(x): return ir.Constant(ir.IntType(64), x) render_llvm = { NumNode: lambda self,ops,ctx: int_const(self.b), MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)), DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)), ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)), - GeNode: lambda self,ops,ctx: ctx.icmp_signed(">=", self.a.render(ops,ctx), int_const(self.b)), LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)), SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)), AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)) diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index cbc4e73adc..32aa00a380 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -31,6 +31,7 @@ class RawBufferCopyIn(RawBuffer): class RawBufferMapped(RawBufferCopyIn): def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented") + # NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688 def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1)) diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index bcde187fbb..f0178791e9 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -25,7 +25,7 @@ class Node: def __neg__(self): return self*-1 def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)]) def __sub__(self, b:Union[Node, int]): return self+-b - def __ge__(self, b:int): return create_node(GeNode(self, b)) + def __ge__(self, b:int): return create_node(LtNode(-self, -b+1)) def __lt__(self, b:int): return create_node(LtNode(self, b)) def __mul__(self, b:int): if b == 0: return NumNode(0) @@ -125,16 +125,12 @@ def create_node(ret:Node): return ret class OpNode(Node): - def __init__(self, a:Node, b:int): + def __init__(self, a:Node, b:int): self.a, self.b = a, b self.min, self.max = self.get_bounds() - @abstractmethod + @abstractmethod def get_bounds(self) -> Tuple[int, int]: pass -class GeNode(OpNode): - def __mul__(self, b: int): return (self.a*b) >= (self.b*b) - def __floordiv__(self, b: int, _=False): return (self.a//b) >= (self.b//b) - def get_bounds(self) -> Tuple[int, int]: return int(self.a.min >= self.b), int(self.a.max >= self.b) class LtNode(OpNode): def __mul__(self, b: int): return (self.a*b) < (self.b*b) def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b) @@ -148,18 +144,18 @@ class MulNode(OpNode): def __mod__(self, b: int): a = (self.a * (self.b%b)) return Node.__mod__(a, b) - def get_bounds(self) -> Tuple[int, int]: + def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b) class DivNode(OpNode): def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div - def get_bounds(self) -> Tuple[int, int]: + def get_bounds(self) -> Tuple[int, int]: assert self.a.min >= 0 return self.a.min//self.b, self.a.max//self.b class ModNode(OpNode): def __floordiv__(self, b: int, factoring_allowed=True): if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod return Node.__floordiv__(self, b, factoring_allowed) - def get_bounds(self) -> Tuple[int, int]: + def get_bounds(self) -> Tuple[int, int]: assert self.a.min >= 0 return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) @@ -194,7 +190,7 @@ class SumNode(RedNode): if m > 1 and b%m == 0: return (self//m)//(b//m) return Node.__floordiv__(self, b, factoring_allowed) - def __mod__(self, b: int): + def __mod__(self, b: int): new_nodes = [] for x in self.nodes: if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b)) @@ -202,7 +198,7 @@ class SumNode(RedNode): else: new_nodes.append(x) return Node.__mod__(Variable.sum(new_nodes), b) -class AndNode(RedNode): +class AndNode(RedNode): def __mul__(self, b: int): Variable.ands([x*b for x in self.nodes]) def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes]) @@ -218,7 +214,6 @@ render_python: Dict[Type, Callable] = { MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})", DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})", ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})", - GeNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}>={self.b})", LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{self.b})", SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})", AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"