diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 38dde06b5e..462232d7f8 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -659,7 +659,7 @@ class Tensor: @staticmethod def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor: assert all_int((r,c)), "does not support symbolic" - return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c) + return (Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)).float() def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self)) def tril(self, k:int=0) -> Tensor: