simpler abs and sign (#4606)

This commit is contained in:
chenyu
2024-05-15 14:33:09 -04:00
committed by GitHub
parent 5ba611787d
commit a5e157f663

View File

@@ -1151,8 +1151,8 @@ class Tensor:
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor: return self + (end - self) * weight
def square(self): return self*self
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
def abs(self): return self.relu() + (-self).relu()
def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
def abs(self): return (self < 0).detach().where(-self, self)
def sign(self): return (self == 0).detach().where(0, self / self.abs()).cast(self.dtype)
def reciprocal(self): return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
# ***** activation functions (unary) *****