From 657e642e3a45795013f644fa608820f8e5e84f2c Mon Sep 17 00:00:00 2001 From: Steven Anderson <34435120+stevenandersonz@users.noreply.github.com> Date: Sun, 4 Jun 2023 09:01:01 -0700 Subject: [PATCH] Fixed test suite for Clip (#912) * Fixed test suite for Clip * fixed issue with clip when taking large negative numbers as min * Remove typings --- extra/onnx_ops.py | 6 ++++-- tinygrad/tensor.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 1ade5251d8..b51eb1d38f 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -152,8 +152,10 @@ def Softmax_1(input, axis=1): return input.softmax(axis) def Softmax_13(input, axis=-1): return input.softmax(axis) Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed def LogSoftmax(input, axis=-1): return input.log_softmax(axis) -def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max) - +def Clip(input, min=None, max=None): + if min is None: min = -3.4e38 + if max is None: max = 3.4e38 + return input.clip(min, max) def Sin(x): return x.sin() def Cos(x): return x.cos() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 877bc03b7a..d45d0c781e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -478,7 +478,7 @@ class Tensor: def sqrt(self): return self.pow(0.5) def rsqrt(self): return self.pow(-0.5) def square(self): return self*self - def clip(self, min_, max_): return ((self-min_).relu()+min_) - (self-max_).relu() + def clip(self, min_, max_): return self.maximum(min_).minimum(max_) def abs(self): return self.relu() + (-self).relu() def sign(self): return self / (self.abs() + 1e-10) def reciprocal(self): return 1.0/self