diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md index 7db33a9304..488b632bbf 100644 --- a/docs/tensor/ops.md +++ b/docs/tensor/ops.md @@ -6,6 +6,7 @@ ::: tinygrad.Tensor.min ::: tinygrad.Tensor.any ::: tinygrad.Tensor.all +::: tinygrad.Tensor.isclose ::: tinygrad.Tensor.mean ::: tinygrad.Tensor.var ::: tinygrad.Tensor.std diff --git a/test/test_ops.py b/test/test_ops.py index 80202a5c2e..b36eb0bd1c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1268,6 +1268,21 @@ class TestOps(unittest.TestCase): def test_all_zero_axis(self): helper_test_op([(1,0,3,0,5)], lambda x: x.all(axis=(1,3)), forward_only=True) + def test_isclose(self): + helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x), forward_only=True) + helper_test_op([(3, 4, 5, 6)], lambda x: x.isclose(x, equal_nan=True), forward_only=True) + helper_test_op(None, lambda x: x.isclose(x + 1e-6), vals=[[1.0, 2.0, 3.0]], forward_only=True) + helper_test_op(None, lambda x: x.isclose(x + 0.1), vals=[[1.0, 2.0, 3.0]], forward_only=True) + helper_test_op(None, lambda x: x.isclose(x + 0.1, rtol=0.2, atol=0.0), vals=[[1.0, 2.0, 3.0]], forward_only=True) + helper_test_op(None, lambda x: x.isclose(x + 1e-9), vals=[[0.0, 0.0, 0.0]], forward_only=True) + helper_test_op([(2, 3, 4)], lambda x: x.isclose(x + 1e-6), forward_only=True) + + def test_isclose_edge_cases(self): + helper_test_op(None, lambda x: x.isclose(x), vals=[[float("inf"), float("-inf"), 1.0]], forward_only=True) + helper_test_op(None, lambda x: x.isclose(x, equal_nan=True), vals=[[float("inf"), float("-inf"), 1.0]], forward_only=True) + helper_test_op(None, lambda x: x.isclose(x), vals=[[float("nan"), 1.0]], forward_only=True) + helper_test_op(None, lambda x: x.isclose(x, equal_nan=True), vals=[[float("nan"), 1.0]], forward_only=True) + def test_mean(self): helper_test_op([(3,4,5,6)], lambda x: x.mean()) helper_test_op([()], lambda x: x.mean()) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 33ec0bc2fe..6b1262acdb 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1715,6 +1715,26 @@ class Tensor(SimpleMathTrait): """ return self.logical_not().any(axis, keepdim).logical_not() + def isclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor: + """ + Returns a new tensor with element-wise comparison of closeness to `other` within a tolerance. + + The `rtol` and `atol` keyword arguments control the relative and absolute tolerance of the comparison. + + By default, two `NaN` values are not close to each other. If `equal_nan` is `True`, two `NaN` values are considered close. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([1, 2, 3]) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([1, 2, 3]) + print(t.isclose(Tensor([1, 2, 3.1])).numpy()) + ``` + """ + is_close = (self - other).abs() <= atol + rtol * other.abs() + return is_close | (self.isnan() & other.isnan()) if equal_nan else is_close & (self.isnan() | other.isnan()).logical_not() + def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): """ Returns the mean value of the tensor along the specified axis or axes.