From 2029cb704753bf88bf9341b49751278e9f6ec191 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 7 Jul 2024 13:04:22 -0400 Subject: [PATCH] support passing None to Tensor.clip (#5319) passing None for no upper bound or no lower bound --- test/test_ops.py | 5 ++++- tinygrad/tensor.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 5dd7f1c114..d71dca7d12 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1698,7 +1698,10 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: x.clip(10, 100)) helper_test_op([(45,65)], lambda x: x.clip(0, 0.1)) helper_test_op([(45,65)], lambda x: x.clip(-0.3, -0.2)) - helper_test_op([(45,65)], lambda x: x.clip(3, 0)) + helper_test_op([(45,65)], lambda x: x.clip(3, 0)) # min > max + helper_test_op([(45,65)], lambda x: x.clip(None, 0)) + helper_test_op([(45,65)], lambda x: x.clip(0, None)) + self.helper_test_exception([(45,65)], lambda x: x.clip(None, None), lambda x: x.clip(None, None), RuntimeError) def test_matvecmat(self): helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b7cfac0624..f69a5af878 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2063,15 +2063,18 @@ class Tensor: ``` """ return self*self - def clip(self, min_, max_): + def clip(self, min_=None, max_=None): """ Clips (clamps) the values in the tensor between `min_` and `max_` element-wise. + If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy()) ``` """ - return self.maximum(min_).minimum(max_) + if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None") + ret = self.maximum(min_) if min_ is not None else self + return ret.minimum(max_) if max_ is not None else ret def sign(self): """ Returns the sign of the tensor element-wise.