From e8879f7e31d9c965391efad069e84e0f11504a1a Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 2 Dec 2025 17:58:32 -0500 Subject: [PATCH] match torch clamp backward (#13533) * match torch clamp backward * fix PYTHON --- test/test_ops.py | 3 +++ tinygrad/tensor.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 11653323c1..995420e56b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2729,6 +2729,9 @@ class TestOps(unittest.TestCase): def test_clip(self): helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2)) + # NOTE: torch set backward to 1 at the boundaries + # https://github.com/pytorch/pytorch/blob/7a41b66367c38d0af3e8a90f7be48d6b281e7bca/tools/autograd/derivatives.yaml#L421 + helper_test_op(None, lambda x: x.clip(-2.5, 1.5), vals=[[-3.0, -2.5, 0, 1.5, 2]]) helper_test_op([(45,65)], lambda x: x.clip(0, 0)) helper_test_op([(45,65)], lambda x: x.clip(10, 100)) helper_test_op([(45,65)], lambda x: x.clip(0, 0.1)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 3428b8054e..77801f5fe4 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3020,8 +3020,8 @@ class Tensor(OpMixin): ``` """ if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None") - ret = self.maximum(min_) if min_ is not None else self - return ret.minimum(max_) if max_ is not None else ret + ret = (self < min_).where(min_, self) if min_ is not None else self + return (ret > max_).where(max_, ret) if max_ is not None else ret def clip(self, min_=None, max_=None) -> Tensor: """