Add Tensor.scatter_reduce (#8947)
* pytorch scatter -> scatter_reduce
* WIP scatter_reduce implementation
* _pre_scatter return type hint
* split out src, mask to satisfy linter
* Add src cast back in
* dict of lambdas instead of ifs
* sum and prod reduction ops with include_self
* add reduce arg error message
* add amax and amin reduction ops
* Fix include_self for higher dims
* Simplify
* Simplify amax and amin too
* Pull include_self logic out into _inv_mask function
* reduce arg cannot be None for scatter_reduce
* Fix self-mask issue
* Add mean reduce op
* Add tests
* any() not needed here
* remove comment
* End support for Tensor src with reduce arg in tinygrad scatter
* Process index, dim inside actual functions
* Add scatter_reduce to onnx
* Add excluded onnx ScatterElements reduction tests back in
* Save 2 lines on the mask helpers
* Update docs
* Add include_self=False tests
* cleanup
* Remove unneeded helper function

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
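A minimal usage sketch of the new method, with values borrowed from the docstring examples in this diff; the printed result is what those semantics imply (`include_self` defaults to `True`, so the destination values join the reduction):

```python
from tinygrad import Tensor, dtypes

src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])  # both src rows scatter into row 0

out = Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce="sum")
print(out.numpy())  # expected [[ 8. 10. 12. 14. 16.]]: 1 plus the column-wise sums of src
```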
@@ -32,6 +32,7 @@
::: tinygrad.Tensor.tril
::: tinygrad.Tensor.interpolate
::: tinygrad.Tensor.scatter
::: tinygrad.Tensor.scatter_reduce

## Neural Network (functional)

@@ -691,9 +691,10 @@ def get_onnx_ops():
else: raise NotImplementedError("reduction doesn't support max or min")
return x

def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
if reduction == "none": return x.scatter(axis, indices, updates)
return x.scatter_reduce(axis, indices, updates, {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}.get(reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)

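For context, the updated ScatterElements above maps ONNX reduction strings onto the new Tensor API (add -> sum, mul -> prod, min -> amin, max -> amax). A minimal sketch of the equivalent direct tinygrad call, using illustrative values that are not taken from this diff:

```python
from tinygrad import Tensor

x = Tensor([[1.0, 2.0, 3.0]])
indices = Tensor([[0, 0, 0]])        # every update lands in column 0 of row 0
updates = Tensor([[2.0, 3.0, 4.0]])

# ONNX reduction="mul" routes to scatter_reduce(..., reduce="prod") per the mapping above
out = x.scatter_reduce(1, indices, updates, reduce="prod")
print(out.numpy())  # expected [[24. 2. 3.]]: 1*2*3*4 in column 0, untouched columns keep self
```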
test/external/external_test_onnx_backend.py
@@ -175,9 +175,7 @@ backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cp
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
backend_test.exclude('test_group_normalization_*') # numerical inaccuracy problem. Current Group Normalization OP fails test

backend_test.exclude('test_scatter_elements_with_reduction_min_cpu') # min not yet supported
backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
backend_test.exclude('test_scatter_elements_with_reduction_max_cpu') # max not yet supported
backend_test.exclude('test_scatternd_max_cpu') # max not yet supported

if Device.DEFAULT in ['GPU', 'METAL']:

@@ -2572,12 +2572,6 @@ class TestOps(unittest.TestCase):
vals=[[1.,2.,3.,4.], [1.,0.]])

def test_scatter_add(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="add"),
lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="add"), forward_only=True)

b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="add"),

@@ -2592,10 +2586,6 @@ class TestOps(unittest.TestCase):
def test_scatter_mul(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="multiply"),
lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="multiply"), forward_only=True)

helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="multiply"),
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="multiply"), forward_only=True)

@@ -2605,10 +2595,74 @@ class TestOps(unittest.TestCase):
lambda x: x.scatter(1, b, float("nan"), reduce="multiply"),
lambda x: x.scatter(1, a, float("nan"), reduce="multiply"), forward_only=True,)

def test_scatter_no_reduce_tensor_src(self):
with self.assertRaises(TypeError):
Tensor.ones(4).scatter(dim=1, index=Tensor([0]), src=Tensor.ones(4), reduce="add")

def test_scatter_reduce_sum(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="sum"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="sum"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="sum", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="sum", include_self=False), forward_only=True)

def test_scatter_reduce_prod(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="prod"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="prod"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="prod", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="prod", include_self=False), forward_only=True)

x = Tensor.zeros([4,5,6]).float()
y = torch.zeros([4,5,6]).float()
helper_test_op([(4,5,6)], lambda src: y.scatter(dim=1, index=b, src=src, reduce="multiply"),
lambda src: x.scatter(dim=1, index=a, src=src, reduce="multiply"), forward_only=True)
helper_test_op([(4,5,6)],
lambda src: y.scatter_reduce(dim=1, index=b, src=src, reduce="prod"),
lambda src: x.scatter_reduce(dim=1, index=a, src=src, reduce="prod"), forward_only=True)

def test_scatter_reduce_mean(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="mean"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="mean"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="mean", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="mean", include_self=False), forward_only=True)

def test_scatter_reduce_amax(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="amax"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="amax"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="amax", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="amax", include_self=False), forward_only=True)

def test_scatter_reduce_amin(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="amin"),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="amin"), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce="amin", include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce="amin", include_self=False), forward_only=True)

def test_scatter_reduce_invalid_reduce_op(self):
with self.assertRaises(TypeError):
Tensor.ones(4).scatter_reduce(dim=0, index=Tensor([0]), src=Tensor.ones(4), reduce="INVALID")

def test_scaled_dot_product_attention(self):
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)

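The new `test_scatter_no_reduce_tensor_src` above captures the behaviour change in this commit: `scatter` with a Tensor `src` combined with `reduce` now raises, and `scatter_reduce` is the replacement. A small sketch with illustrative values (not taken from the tests):

```python
from tinygrad import Tensor

x, idx, src = Tensor.ones(4), Tensor([0, 1]), Tensor([10.0, 20.0])
try:
  x.scatter(0, idx, src, reduce="add")  # Tensor src combined with reduce now raises TypeError
except TypeError as e:
  print(e)
# the supported path: reduce is one of "sum", "prod", "mean", "amax", "amin"
print(x.scatter_reduce(0, idx, src, reduce="sum").numpy())  # expected [11. 21. 1. 1.]
```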
@@ -2430,11 +2430,26 @@ class Tensor(SimpleMathTrait):
x = x.gather(i, index)
return x.cast(self.dtype)

def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
# shrink src to index shape to shrink away the unused values
src = src.shrink(tuple((0,s) for s in index.shape))
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
return src, mask

def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `add` or `multiply` reduction operation with `reduce`.

NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.

```python exec="true" source="above" session="tensor" result="python"
src = Tensor.arange(1, 11).reshape(2, 5)
print(src.numpy())
@@ -2455,22 +2470,57 @@ class Tensor(SimpleMathTrait):
```
"""
if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
index, dim = index.to(self.device), self._resolve_dim(dim)
if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce")
src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
# shrink src to index shape to shrink away the unused values
src = src.shrink(tuple((0,s) for s in index.shape))
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
index, dim = index.to(self.device), self._resolve_dim(dim)
src, mask = self._pre_scatter(dim, index, src)
if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
return _masked_setitem(self, src, mask, (-1,))

def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
include_self:bool=True) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.

Set `include_self=False` to exclude values in the `self` Tensor from the reduction.

```python exec="true" source="above" session="tensor" result="python"
src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
print(src.numpy())
index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
print(index.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
```
"""
src = src.cast(self.dtype)
index, dim = index.to(self.device), self._resolve_dim(dim)
src, mask = self._pre_scatter(dim, index, src)
def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
if reduce == "mean":
count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
raise TypeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")

# ***** unary ops *****

def logical_not(self):
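A quick numerical sketch of the `include_self` handling implemented above via `_inv_mask`, with illustrative values: positions that receive at least one `src` value drop `self` from the reduction when `include_self=False`, while untouched positions keep `self` unchanged.

```python
from tinygrad import Tensor

x = Tensor([10.0, 10.0, 10.0])
idx, src = Tensor([0, 0]), Tensor([1.0, 3.0])  # both src values target position 0

print(x.scatter_reduce(0, idx, src, reduce="mean").numpy())
# expected [4.6667 10. 10.]: (10 + 1 + 3) / 3 at position 0, self kept elsewhere
print(x.scatter_reduce(0, idx, src, reduce="mean", include_self=False).numpy())
# expected [2. 10. 10.]: (1 + 3) / 2 at position 0, untouched positions still return self
```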