Fix max nan (#1298)

* Fix max nan

* Adds nan check option to max function
* Calls to max can pass in "ignore_nan=True" argument
* Added max nan CI tests

* Fix max nan

* Adds nan check option to max function
* Calls to max can pass in "ignore_nan=True" argument
* Added max nan CI tests
* Turned off due to the need for granularity
This commit is contained in:
uncommonSensor
2023-07-23 19:39:44 -07:00
committed by GitHub
parent a0965ee198
commit 50774470b2
2 changed files with 12 additions and 4 deletions

View File

@@ -381,8 +381,8 @@ class Tensor:
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
def max(self, axis=None, keepdim=False, ignore_nan=True): return self._reduce(mlops.Max, axis, keepdim) if ignore_nan else self._reduce(mlops.Max, axis, keepdim) + (self.isnan() * np.nan).sum()
def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)
@@ -672,6 +672,7 @@ class Tensor:
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
def isnan(self) -> Tensor: return (self != self)
# register functions to move between devices
for device in Device._buffers: