From c22e77abfdb55f8248db852126737fbecfd52a7b Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Wed, 26 Jul 2023 22:14:54 -0400 Subject: [PATCH] Match torch on fractional negative base pow (#1352) * feat: match torch on fractional negative base pow * feat: tests for trunc --- test/test_ops.py | 9 +++++++++ tinygrad/tensor.py | 14 +++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 74aacc9014..c3ae86c2b7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -182,6 +182,10 @@ class TestOps(unittest.TestCase): tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward) + def test_trunc(self): + helper_test_op([(45,65)], lambda x: torch.trunc(x), lambda x: x.trunc(), forward_only=True) + a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) + helper_test_op([], lambda: torch.trunc(b), lambda: Tensor.trunc(a), forward_only=True) def test_floor(self): helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True) a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) @@ -271,6 +275,11 @@ class TestOps(unittest.TestCase): # Regression tests for https://github.com/tinygrad/tinygrad/issues/1151 helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10) helper_test_op([()], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10) + # Regression tests for https://github.com/tinygrad/tinygrad/issues/1251 + helper_test_op([(45,65)], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10) + helper_test_op([(45,65)], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10) + helper_test_op([()], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10) + helper_test_op([()], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10) def test_pow_const(self): helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0) helper_test_op([(45,65)], lambda x: x**-1.0, lambda x: x**-1.0) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 838c782819..1d31a74eed 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -513,12 +513,9 @@ class Tensor: def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self) # ***** math functions (unary) ***** - def ceil(self: Tensor) -> Tensor: - b = self.cast(dtypes.int32).contiguous().cast(self.dtype) - return (self > b).where(b+1, b) - def floor(self: Tensor) -> Tensor: - b = self.cast(dtypes.int32).contiguous().cast(self.dtype) - return (self < b).where(b-1, b) + def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype) + def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b) + def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b) def __neg__(self): return 0.0-self def square(self): return self*self @@ -580,7 +577,10 @@ class Tensor: sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos() # we only need to correct the sign if the base is negative base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else copysign(1, x)) - 1) / -2 - return ar.mul(sign * base_sign + (1 - base_sign)) + # inject nan if the base is negative and the power is not an integer + to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign + inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan") + return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan) def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x) def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)