mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 02:21:40 -05:00
Update arange to be (start, stop, step) (#1308)
This commit is contained in:
@@ -152,7 +152,9 @@ class Tensor:
|
||||
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def arange(stop, start=0, step=1, **kwargs): return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
|
||||
def arange(start, stop=None, step=1, **kwargs):
|
||||
if stop is None: stop, start = start, 0
|
||||
return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
|
||||
|
||||
@staticmethod
|
||||
def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
|
||||
@@ -499,7 +501,7 @@ class Tensor:
|
||||
def tan(self): return self.sin() / self.cos()
|
||||
|
||||
@staticmethod
|
||||
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c)
|
||||
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
|
||||
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
|
||||
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user