diff --git a/test/test_ops.py b/test/test_ops.py index d16ce26ab2..067f6d2198 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1095,17 +1095,10 @@ class TestOps(unittest.TestCase): helper_test_op([(4,4)], lambda x: x[:, 1:2][0:1]) helper_test_op([(4,4)], lambda x: x[:, 1:2][:, 0:1]) - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU does not support nan/inf") - def test_max_nan(self): - n = Tensor([1, float("nan")]).max(ignore_nan=False).numpy() + @unittest.skip("this test is broken #862") + def test_max_nan(self): + n = Tensor([1, float("nan")]).max().numpy() assert math.isnan(n.item()), f"{n.item()} is not nan" - n = Tensor([float("nan"), 1]).max(ignore_nan=False).numpy() - assert math.isnan(n.item()), f"{n.item()} is not nan" - - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU does not support nan/inf") - def test_max_float4_nan(self): - n = Tensor(np.array([[1.0, 2.0, np.nan, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=np.float32)).max(ignore_nan=False).numpy() - assert math.isnan(n.item()), f"tinygrad max: {n.item()} is not nan" def test_inf_where(self): x = Tensor.full((3, 3), float("inf")) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index bd9995ae3a..e5c358ca73 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -381,8 +381,8 @@ class Tensor: return ret if keepdim else ret.reshape(shape=shape) def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim) + def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim) def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim)) - def max(self, axis=None, keepdim=False, ignore_nan=True): return self._reduce(mlops.Max, axis, keepdim) if ignore_nan else self._reduce(mlops.Max, axis, keepdim) + (self.isnan() * np.nan).sum() def mean(self, axis=None, keepdim=False): out = self.sum(axis=axis, keepdim=keepdim) @@ -672,7 +672,6 @@ class Tensor: def element_size(self) -> int: return self.dtype.itemsize def 
nbytes(self) -> int: return self.numel() * self.element_size() def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype) - def isnan(self) -> Tensor: return (self != self) # register functions to move between devices for device in Device._buffers: