From 5eef59d7327f8dd10d2d39ac8bbce66d4de20a34 Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:29:36 +0800 Subject: [PATCH] add Tensor.linspace (#7609) * add linspace * shave off tests and forgot to add to docs crap * WHOOPS * better tests --- docs/tensor/creation.md | 1 + test/test_ops.py | 10 ++++++++++ tinygrad/tensor.py | 20 ++++++++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/docs/tensor/creation.md b/docs/tensor/creation.md index 897c29ad38..722d1a41e0 100644 --- a/docs/tensor/creation.md +++ b/docs/tensor/creation.md @@ -5,6 +5,7 @@ ::: tinygrad.Tensor.ones ::: tinygrad.Tensor.full ::: tinygrad.Tensor.arange +::: tinygrad.Tensor.linspace ::: tinygrad.Tensor.eye ::: tinygrad.Tensor.full_like ::: tinygrad.Tensor.zeros_like diff --git a/test/test_ops.py b/test/test_ops.py index ca1a95af62..64daff191e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -233,6 +233,16 @@ class TestOps(unittest.TestCase): def test_arange_4096(self): helper_test_op([], lambda: torch.arange(4096, dtype=torch.int32), lambda: Tensor.arange(4096), forward_only=True) + def test_linspace(self): + helper_test_op([], lambda: torch.linspace(5, 10, 3), lambda: Tensor.linspace(5, 10, 3), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 1), lambda: Tensor.linspace(5, 10, 1), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 0), lambda: Tensor.linspace(5, 10, 0), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 30), lambda: Tensor.linspace(5, 10, 30), forward_only=True) + helper_test_op([], lambda: torch.linspace(-5.5, 5.5, 10), lambda: Tensor.linspace(-5.5, 5.5, 10), forward_only=True) + helper_test_op([], lambda: torch.linspace(5.5, -5.5, 10), lambda: Tensor.linspace(5.5, -5.5, 10), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 3, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 3, dtype=dtypes.int32), forward_only=True) + self.helper_test_exception([], lambda: torch.linspace(1, 2, -1), lambda: Tensor.linspace(1, 2, -1), expected=(RuntimeError, ValueError)) + def test_sum_fake(self): helper_test_op([(256, 1)], lambda x: x.sum(axis=1)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cc76bad61d..15df5e5bfe 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -612,6 +612,26 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs) return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype) + @staticmethod + def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor: + """ + Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive. + + You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. + Additionally, all other keyword arguments are passed to the constructor of the tensor. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.linspace(0, 10, 5).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.linspace(-1, 1, 5).numpy()) + ``` + """ + if steps < 0: raise ValueError("number of steps must be non-negative") + dtype = kwargs.pop("dtype", dtypes.default_float) + if steps == 1: return Tensor([start], dtype=dtype, **kwargs) + return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype) + @staticmethod def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor: """