diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 1ade5251d8..b51eb1d38f 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -152,8 +152,10 @@ def Softmax_1(input, axis=1): return input.softmax(axis) def Softmax_13(input, axis=-1): return input.softmax(axis) Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed def LogSoftmax(input, axis=-1): return input.log_softmax(axis) -def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max) - +def Clip(input, min=None, max=None): + if min is None: min = -3.4e38 + if max is None: max = 3.4e38 + return input.clip(min, max) def Sin(x): return x.sin() def Cos(x): return x.cos() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 877bc03b7a..d45d0c781e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -478,7 +478,7 @@ class Tensor: def sqrt(self): return self.pow(0.5) def rsqrt(self): return self.pow(-0.5) def square(self): return self*self - def clip(self, min_, max_): return ((self-min_).relu()+min_) - (self-max_).relu() + def clip(self, min_, max_): return self.maximum(min_).minimum(max_) def abs(self): return self.relu() + (-self).relu() def sign(self): return self / (self.abs() + 1e-10) def reciprocal(self): return 1.0/self