Adds tril & triu support (#936)

* triu & tril support

* lint and kernel count error

* switched shape indicies

* larger shape tests

* reverted numpy removal until #942 is resolved
This commit is contained in:
Diogo
2023-06-10 01:13:20 -04:00
committed by GitHub
parent 48e9461197
commit 2d4370b487
7 changed files with 30 additions and 6 deletions

View File

@@ -481,6 +481,12 @@ class Tensor:
def sin(self): return mlops.Sin.apply(self)
def cos(self): return ((math.pi/2)-self).sin()
def tan(self): return self.sin() / self.cos()
@staticmethod
def _tri(r:int, c:int, k:int=0) -> Tensor: return Tensor.arange(r).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1).where(Tensor.zeros_like(self), self)
# ***** math functions (unary) *****
def __neg__(self): return 0.0-self