mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 18:35:12 -05:00
Adds tril & triu support (#936)
* triu & tril support * lint and kernel count error * switched shape indicies * larger shape tests * reverted numpy removal until #942 is resolved
This commit is contained in:
@@ -481,6 +481,12 @@ class Tensor:
|
||||
def sin(self): return mlops.Sin.apply(self)
|
||||
def cos(self): return ((math.pi/2)-self).sin()
|
||||
def tan(self): return self.sin() / self.cos()
|
||||
|
||||
@staticmethod
|
||||
def _tri(r:int, c:int, k:int=0) -> Tensor: return Tensor.arange(r).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k).unsqueeze(0).expand(r,c)
|
||||
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k).where(self, Tensor.zeros_like(self))
|
||||
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1).where(Tensor.zeros_like(self), self)
|
||||
|
||||
# ***** math functions (unary) *****
|
||||
|
||||
def __neg__(self): return 0.0-self
|
||||
|
||||
Reference in New Issue
Block a user