add Tensor.linspace (#7609)

* add linspace

* shave off tests and forgot to add to docs crap

* WHOOPS

* better tests
This commit is contained in:
geohotstan
2024-11-12 10:29:36 +08:00
committed by GitHub
parent 99f29e50b2
commit 5eef59d732
3 changed files with 31 additions and 0 deletions

View File

@@ -5,6 +5,7 @@
::: tinygrad.Tensor.ones
::: tinygrad.Tensor.full
::: tinygrad.Tensor.arange
::: tinygrad.Tensor.linspace
::: tinygrad.Tensor.eye
::: tinygrad.Tensor.full_like
::: tinygrad.Tensor.zeros_like

View File

@@ -233,6 +233,16 @@ class TestOps(unittest.TestCase):
def test_arange_4096(self):
helper_test_op([], lambda: torch.arange(4096, dtype=torch.int32), lambda: Tensor.arange(4096), forward_only=True)
def test_linspace(self):
helper_test_op([], lambda: torch.linspace(5, 10, 3), lambda: Tensor.linspace(5, 10, 3), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 1), lambda: Tensor.linspace(5, 10, 1), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 0), lambda: Tensor.linspace(5, 10, 0), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 30), lambda: Tensor.linspace(5, 10, 30), forward_only=True)
helper_test_op([], lambda: torch.linspace(-5.5, 5.5, 10), lambda: Tensor.linspace(-5.5, 5.5, 10), forward_only=True)
helper_test_op([], lambda: torch.linspace(5.5, -5.5, 10), lambda: Tensor.linspace(5.5, -5.5, 10), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 3, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 3, dtype=dtypes.int32), forward_only=True)
self.helper_test_exception([], lambda: torch.linspace(1, 2, -1), lambda: Tensor.linspace(1, 2, -1), expected=(RuntimeError, ValueError))
def test_sum_fake(self):
helper_test_op([(256, 1)], lambda x: x.sum(axis=1))

View File

@@ -612,6 +612,26 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
@staticmethod
def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
"""
Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.linspace(0, 10, 5).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.linspace(-1, 1, 5).numpy())
```
"""
if steps < 0: raise ValueError("number of steps must be non-negative")
dtype = kwargs.pop("dtype", dtypes.default_float)
if steps == 1: return Tensor([start], dtype=dtype, **kwargs)
return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
@staticmethod
def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
"""