mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 02:21:40 -05:00
Match torch on fractional negative base pow (#1352)
* feat: match torch on fractional negative base pow * feat: tests for trunc
This commit is contained in:
@@ -513,12 +513,9 @@ class Tensor:
|
||||
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
|
||||
|
||||
# ***** math functions (unary) *****
|
||||
def ceil(self: Tensor) -> Tensor:
|
||||
b = self.cast(dtypes.int32).contiguous().cast(self.dtype)
|
||||
return (self > b).where(b+1, b)
|
||||
def floor(self: Tensor) -> Tensor:
|
||||
b = self.cast(dtypes.int32).contiguous().cast(self.dtype)
|
||||
return (self < b).where(b-1, b)
|
||||
def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
|
||||
def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b)
|
||||
def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b)
|
||||
|
||||
def __neg__(self): return 0.0-self
|
||||
def square(self): return self*self
|
||||
@@ -580,7 +577,10 @@ class Tensor:
|
||||
sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos()
|
||||
# we only need to correct the sign if the base is negative
|
||||
base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else copysign(1, x)) - 1) / -2
|
||||
return ar.mul(sign * base_sign + (1 - base_sign))
|
||||
# inject nan if the base is negative and the power is not an integer
|
||||
to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign
|
||||
inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
|
||||
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
|
||||
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
|
||||
|
||||
def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
|
||||
|
||||
Reference in New Issue
Block a user