Revert "Fix max nan (#1298)" (#1334)

This reverts commit 50774470b2.
This commit is contained in:
George Hotz
2023-07-23 20:41:28 -07:00
committed by GitHub
parent 50774470b2
commit 086382b64e
2 changed files with 4 additions and 12 deletions

View File

@@ -381,8 +381,8 @@ class Tensor:
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
def max(self, axis=None, keepdim=False, ignore_nan=True): return self._reduce(mlops.Max, axis, keepdim) if ignore_nan else self._reduce(mlops.Max, axis, keepdim) + (self.isnan() * np.nan).sum()
def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)
@@ -672,7 +672,6 @@ class Tensor:
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
def isnan(self) -> Tensor: return (self != self)
# register functions to move between devices
for device in Device._buffers: