mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
hotfix don't overwrite acc dtype in scatter_reduce (#10606)
dtype is inferred by individul reduce
This commit is contained in:
@@ -2676,14 +2676,13 @@ class Tensor(MathTrait):
|
||||
"""
|
||||
src, mask = self._pre_scatter(dim, index, src)
|
||||
def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
|
||||
# TODO: should not overwrite dtype here?
|
||||
if reduce == "sum": return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
|
||||
if reduce == "prod": return mask.where(src, 1).prod(-1, dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
|
||||
if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
|
||||
if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
|
||||
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
|
||||
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
|
||||
if reduce == "mean":
|
||||
count = mask.where(1, 0).sum(-1, dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
|
||||
return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
|
||||
count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0))
|
||||
return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)
|
||||
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
|
||||
|
||||
def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user