This commit is contained in:
qazal
2023-12-22 22:38:23 +02:00
parent f6b6024350
commit 7ebac1018f

View File

@@ -659,7 +659,7 @@ class Tensor:
@staticmethod
def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
assert all_int((r,c)), "does not support symbolic"
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
return (Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)).float()
def triu(self, k:int=0) -> Tensor:
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor: