diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7e55214099..20e7d45b5a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2676,14 +2676,13 @@ class Tensor(MathTrait): """ src, mask = self._pre_scatter(dim, index, src) def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b) - # TODO: should not overwrite dtype here? - if reduce == "sum": return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)) - if reduce == "prod": return mask.where(src, 1).prod(-1, dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1)) + if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)) + if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1)) if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m)) if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m)) if reduce == "mean": - count = mask.where(1, 0).sum(-1, dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0)) - return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count) + count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0)) + return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count) raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'") def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]: