update Tensor.scatter doc examples (#7924)

same example from torch, i think it's much more useful
This commit is contained in:
chenyu
2024-11-27 11:42:36 -05:00
committed by GitHub
parent cea5853cfa
commit 57262c8e34

View File

@@ -2333,20 +2333,28 @@ class Tensor(SimpleMathTrait):
x = x.gather(i, index)
return x.cast(self.dtype)
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']] = None) -> Tensor:
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `add` or `multiply` reduction operation with `reduce`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2], [3, 4]])
print(t.numpy())
src = Tensor.arange(1, 11).reshape(2, 5)
print(src.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.scatter(dim=1, index=Tensor([[0, 0], [1, 0]]), src=9).numpy())
index = Tensor([[0, 1, 2, 0]])
print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.scatter(dim=1, index=Tensor([[0, 0], [1, 0]]), src=Tensor([[3, 3], [9, 9]]), reduce="add").numpy())
index = Tensor([[0, 1, 2], [0, 1, 4]])
print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
```
"""
index, dim = index.to(self.device), self._resolve_dim(dim)