Update arange to be (start, stop, step) (#1308)

This commit is contained in:
madt2709
2023-07-20 21:27:23 -07:00
committed by GitHub
parent f45013f0a3
commit d2c1e8409a
4 changed files with 10 additions and 8 deletions

View File

@@ -152,7 +152,9 @@ class Tensor:
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
@staticmethod
def arange(stop, start=0, step=1, **kwargs): return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
def arange(start, stop=None, step=1, **kwargs):
if stop is None: stop, start = start, 0
return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
@staticmethod
def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
@@ -499,7 +501,7 @@ class Tensor:
def tan(self): return self.sin() / self.cos()
@staticmethod
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c)
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)