mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
fix tril
This commit is contained in:
@@ -659,7 +659,7 @@ class Tensor:
|
||||
@staticmethod
|
||||
def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
|
||||
assert all_int((r,c)), "does not support symbolic"
|
||||
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
|
||||
return (Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)).float()
|
||||
def triu(self, k:int=0) -> Tensor:
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
|
||||
def tril(self, k:int=0) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user