mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
update Tensor.scatter doc examples (#7924)
same example from torch, i think it's much more useful
This commit is contained in:
@@ -2333,20 +2333,28 @@ class Tensor(SimpleMathTrait):
|
||||
x = x.gather(i, index)
|
||||
return x.cast(self.dtype)
|
||||
|
||||
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']] = None) -> Tensor:
|
||||
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
|
||||
"""
|
||||
Scatters `src` values along an axis specified by `dim`.
|
||||
Apply `add` or `multiply` reduction operation with `reduce`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 2], [3, 4]])
|
||||
print(t.numpy())
|
||||
src = Tensor.arange(1, 11).reshape(2, 5)
|
||||
print(src.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.scatter(dim=1, index=Tensor([[0, 0], [1, 0]]), src=9).numpy())
|
||||
index = Tensor([[0, 1, 2, 0]])
|
||||
print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.scatter(dim=1, index=Tensor([[0, 0], [1, 0]]), src=Tensor([[3, 3], [9, 9]]), reduce="add").numpy())
|
||||
index = Tensor([[0, 1, 2], [0, 1, 4]])
|
||||
print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
|
||||
```
|
||||
"""
|
||||
index, dim = index.to(self.device), self._resolve_dim(dim)
|
||||
|
||||
Reference in New Issue
Block a user