Add Tensor.scatter_reduce (#8947)

* pytorch scatter -> scatter_reduce

* WIP scatter_reduce implementation

* _pre_scatter return type hint

* split out src, mask to satisfy linter

* Add src cast back in

* dict of lambdas instead of ifs

* sum and prod reduction ops with include_self

* add reduce arg error message

* add amax and amin reduction ops

* Fix include_self for higher dims

* Simplify

* Simplify amax and amin too

* Pull include_self logic out into _inv_mask function

* reduce arg cannot be None for scatter_reduce

* Fix self-mask issue

* Add mean reduce op

* Add tests

* any() not needed here

* remove comment

* End support for Tensor src with reduce arg in tinygrad scatter

* Process index, dim inside actual functions

* Add scatter_reduce to onnx

* Add excluded onnx ScatterElements reduction tests back in

* Save 2 lines on the mask helpers

* Update docs

* Add include_self=False tests

* cleanup

* Remove unneeded helper function

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Authored by Josh Moore on 2025-02-13 09:08:54 -05:00, committed by GitHub
parent 2b9ce1235a
commit 1f9d2442b9
5 changed files with 131 additions and 27 deletions


@@ -32,6 +32,7 @@
::: tinygrad.Tensor.tril
::: tinygrad.Tensor.interpolate
::: tinygrad.Tensor.scatter
::: tinygrad.Tensor.scatter_reduce
## Neural Network (functional)


@@ -691,9 +691,10 @@ def get_onnx_ops():
else: raise NotImplementedError("reduction doesn't support max or min")
return x
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
if reduction == "none": return x.scatter(axis, indices, updates)
return x.scatter_reduce(axis, indices, updates, {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}.get(reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)
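
The new ScatterElements path routes every non-`"none"` reduction through `Tensor.scatter_reduce`, translating ONNX's reduction names (`add`, `mul`, `min`, `max`) into tinygrad's (`sum`, `prod`, `amin`, `amax`). A minimal sketch of what that mapping means at the Tensor level, assuming a tinygrad build that includes this commit (shapes and values below are illustrative, not taken from the diff):

```python
from tinygrad import Tensor

x = Tensor([[1.0, 2.0, 3.0]])
indices = Tensor([[0, 0, 2]])            # already non-negative, so no wrapping needed
updates = Tensor([[10.0, 20.0, 30.0]])

# ONNX reduction="max" maps to tinygrad reduce="amax"; the existing values of x
# participate in the reduction because include_self defaults to True.
print(x.scatter_reduce(1, indices, updates, reduce="amax").numpy())
# -> [[20.  2. 30.]]  (max(1,10,20), untouched 2, max(3,30))
```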


@@ -175,9 +175,7 @@ backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cp
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
backend_test.exclude('test_group_normalization_*') # numerical inaccuracy problem. Current Group Normalization OP fails test
backend_test.exclude('test_scatter_elements_with_reduction_min_cpu') # min not yet supported
backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
backend_test.exclude('test_scatter_elements_with_reduction_max_cpu') # max not yet supported
backend_test.exclude('test_scatternd_max_cpu') # max not yet supported
if Device.DEFAULT in ['GPU', 'METAL']:


@@ -2572,12 +2572,6 @@ class TestOps(unittest.TestCase):
vals=[[1.,2.,3.,4.], [1.,0.]])
def test_scatter_add(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="add"),
lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="add"), forward_only=True)
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="add"),
@@ -2592,10 +2586,6 @@ class TestOps(unittest.TestCase):
def test_scatter_mul(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="multiply"),
lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="multiply"), forward_only=True)
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="multiply"),
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="multiply"), forward_only=True)
@@ -2605,10 +2595,74 @@ class TestOps(unittest.TestCase):
lambda x: x.scatter(1, b, float("nan"), reduce="multiply"),
lambda x: x.scatter(1, a, float("nan"), reduce="multiply"), forward_only=True,)
def test_scatter_no_reduce_tensor_src(self):
with self.assertRaises(TypeError):
Tensor.ones(4).scatter(dim=1, index=Tensor([0]), src=Tensor.ones(4), reduce="add")
def test_scatter_reduce_sum(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="sum"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="sum"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="sum", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="sum", include_self=False), forward_only=True)
def test_scatter_reduce_prod(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="prod"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="prod"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="prod", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="prod", include_self=False), forward_only=True)
x = Tensor.zeros([4,5,6]).float()
y = torch.zeros([4,5,6]).float()
helper_test_op([(4,5,6)], lambda src: y.scatter(dim=1, index=b, src=src, reduce="multiply"),
lambda src: x.scatter(dim=1, index=a, src=src, reduce="multiply"), forward_only=True)
helper_test_op([(4,5,6)],
lambda src: y.scatter_reduce(dim=1, index=b, src=src, reduce="prod"),
lambda src: x.scatter_reduce(dim=1, index=a, src=src, reduce="prod"), forward_only=True)
def test_scatter_reduce_mean(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="mean"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="mean"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="mean", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="mean", include_self=False), forward_only=True)
def test_scatter_reduce_amax(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="amax"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="amax"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="amax", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="amax", include_self=False), forward_only=True)
def test_scatter_reduce_amin(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="amin"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="amin"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="amin", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="amin", include_self=False), forward_only=True)
def test_scatter_reduce_invalid_reduce_op(self):
with self.assertRaises(TypeError):
Tensor.ones(4).scatter_reduce(dim=0, index=Tensor([0]), src=Tensor.ones(4), reduce="INVALID")
def test_scaled_dot_product_attention(self):
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)


@@ -2430,11 +2430,26 @@ class Tensor(SimpleMathTrait):
x = x.gather(i, index)
return x.cast(self.dtype)
def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
# shrink src to index shape to shrink away the unused values
src = src.shrink(tuple((0,s) for s in index.shape))
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
return src, mask
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `add` or `multiply` reduction operation with `reduce`.
NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.
```python exec="true" source="above" session="tensor" result="python"
src = Tensor.arange(1, 11).reshape(2, 5)
print(src.numpy())
@@ -2455,22 +2470,57 @@ class Tensor(SimpleMathTrait):
```
"""
if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
index, dim = index.to(self.device), self._resolve_dim(dim)
if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce")
src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
# shrink src to index shape to shrink away the unused values
src = src.shrink(tuple((0,s) for s in index.shape))
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
index, dim = index.to(self.device), self._resolve_dim(dim)
src, mask = self._pre_scatter(dim, index, src)
if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
return _masked_setitem(self, src, mask, (-1,))
def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
include_self:bool=True) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.
Set `include_self=False` to exclude values in the `self` Tensor from the reduction.
```python exec="true" source="above" session="tensor" result="python"
src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
print(src.numpy())
index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
print(index.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
```
"""
src = src.cast(self.dtype)
index, dim = index.to(self.device), self._resolve_dim(dim)
src, mask = self._pre_scatter(dim, index, src)
def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
if reduce == "mean":
count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
raise TypeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
# ***** unary ops *****
def logical_not(self):
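
Taken together, `_pre_scatter` expands `src` along a new axis of length `self.shape[dim]` and builds a matching one-hot `mask` from `index`, so every reduction becomes a masked reduce over that axis; with `include_self=False`, `_inv_mask` keeps the original value only at positions that receive no `src` contribution. A rough loop-based NumPy reference for the `"sum"` case (not from the commit; the function name and sample values are illustrative) makes those semantics concrete:

```python
import numpy as np

def scatter_reduce_sum_ref(x, dim, index, src, include_self=True):
    # Reference semantics for reduce="sum": each src[pos] is accumulated into out
    # at the position obtained by replacing pos[dim] with index[pos].
    out = np.array(x, dtype=float)
    touched = np.zeros(x.shape, dtype=bool)
    for pos in np.ndindex(*index.shape):
        tgt = list(pos)
        tgt[dim] = index[pos]
        tgt = tuple(tgt)
        if not include_self and not touched[tgt]:
            out[tgt] = 0.0  # drop the original value before the first contribution
        touched[tgt] = True
        out[tgt] += src[pos]
    return out  # positions never scattered to keep x's value either way

x = np.ones((1, 5))
index = np.zeros((2, 5), dtype=int)
src = np.arange(1.0, 11.0).reshape(2, 5)
print(scatter_reduce_sum_ref(x, 0, index, src))                      # [[ 8. 10. 12. 14. 16.]]
print(scatter_reduce_sum_ref(x, 0, index, src, include_self=False))  # [[ 7.  9. 11. 13. 15.]]
```

This mirrors the docstring example above, where `Tensor.ones(1, 5)` with an all-zero index sums each column of `src`, and `include_self=False` leaves the original ones out of the sums.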