hotfix don't overwrite acc dtype in scatter_reduce (#10606)

dtype is inferred by individul reduce
This commit is contained in:
chenyu
2025-06-02 21:17:01 -04:00
committed by GitHub
parent ba02a6331e
commit 26dee71bc1

View File

@@ -2676,14 +2676,13 @@ class Tensor(MathTrait):
"""
src, mask = self._pre_scatter(dim, index, src)
def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
# TODO: should not overwrite dtype here?
if reduce == "sum": return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1, dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
if reduce == "mean":
count = mask.where(1, 0).sum(-1, dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0))
return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]: