From 57262c8e34d9a4d0c736dfe3a71ff6b1f57ea7b5 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 27 Nov 2024 11:42:36 -0500 Subject: [PATCH] update Tensor.scatter doc examples (#7924) same example from torch, i think it's much more useful --- tinygrad/tensor.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 9042be4d5a..2ee9f93a81 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2333,20 +2333,28 @@ class Tensor(SimpleMathTrait): x = x.gather(i, index) return x.cast(self.dtype) - def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']] = None) -> Tensor: + def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor: """ Scatters `src` values along an axis specified by `dim`. Apply `add` or `multiply` reduction operation with `reduce`. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([[1, 2], [3, 4]]) - print(t.numpy()) + src = Tensor.arange(1, 11).reshape(2, 5) + print(src.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" - print(t.scatter(dim=1, index=Tensor([[0, 0], [1, 0]]), src=9).numpy()) + index = Tensor([[0, 1, 2, 0]]) + print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" - print(t.scatter(dim=1, index=Tensor([[0, 0], [1, 0]]), src=Tensor([[3, 3], [9, 9]]), reduce="add").numpy()) + index = Tensor([[0, 1, 2], [0, 1, 4]]) + print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy()) ``` """ index, dim = index.to(self.device), self._resolve_dim(dim)