Mirror of https://github.com/tinygrad/tinygrad.git
Fix max nan (#1298)
* Fix max nan
* Adds nan check option to max function
* Calls to max can pass in "ignore_nan=True" argument
* Added max nan CI tests
* Turned off due to the need for granularity
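A minimal usage sketch of the new argument, assuming a tinygrad build with this patch applied; what the plain Max reduce itself does with NaN inputs is backend-dependent:

from tinygrad.tensor import Tensor

t = Tensor([1.0, float('nan'), 3.0])

# ignore_nan=True (the default) takes the same Max reduce path as before the patch.
print(t.max().numpy())

# ignore_nan=False adds a NaN-propagating term, so a NaN anywhere in the input
# is meant to surface in the result.
print(t.max(ignore_nan=False).numpy())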
@@ -381,8 +381,8 @@ class Tensor:
     return ret if keepdim else ret.reshape(shape=shape)
 
   def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
-  def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
+  def max(self, axis=None, keepdim=False, ignore_nan=True): return self._reduce(mlops.Max, axis, keepdim) if ignore_nan else self._reduce(mlops.Max, axis, keepdim) + (self.isnan() * np.nan).sum()
 
   def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
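The ignore_nan=False branch leans on IEEE NaN propagation: isnan() marks the NaN positions, and multiplying that mask by NaN before summing folds a NaN into the reduced value whenever one is present. A minimal numpy sketch of that mechanism (numpy stands in for the Tensor ops here):

import numpy as np

x = np.array([1.0, np.nan, 3.0])
mask = (x != x)               # NaN is the only float not equal to itself
print(mask)                   # [False  True False]
print((mask * np.nan).sum())  # nan, so max(...) + this term becomes NaN

One caveat: in strict IEEE arithmetic 0.0 * NaN is also NaN, so the correction term would poison NaN-free inputs too unless the backend zeroes out the masked product; that may be the granularity problem the last commit-message bullet alludes to.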
@@ -672,6 +672,7 @@ class Tensor:
   def element_size(self) -> int: return self.dtype.itemsize
   def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
+  def isnan(self) -> Tensor: return (self != self)
 
 # register functions to move between devices
 for device in Device._buffers:
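The new isnan helper uses that same self-inequality trick. A minimal usage sketch, again assuming this patch is applied:

from tinygrad.tensor import Tensor

t = Tensor([1.0, float('nan'), 3.0])
# (self != self) is nonzero exactly at the NaN positions, e.g. [0. 1. 0.] here.
print(t.isnan().numpy())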