From aa4a0de287522f965e0962457a1bba64948c7d1b Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 13 Dec 2023 11:39:20 -0500 Subject: [PATCH] simpler Tensor.pow to integer (#2746) --- test/test_ops.py | 2 ++ test/test_schedule.py | 9 ++++++++- tinygrad/tensor.py | 6 ++---- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index feca869f4b..240eb3eba1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -300,6 +300,8 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0) def test_pow(self): # TODO: why is a=0 for these tests? + helper_test_op([(45,65)], lambda x: x**0, lambda x: Tensor.pow(x,0), a=0) + helper_test_op([(45,65)], lambda x: x**1, lambda x: Tensor.pow(x,1), a=0) helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0) helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0) helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0) diff --git a/test/test_schedule.py b/test/test_schedule.py index 952f996749..f1a2fe9418 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -327,11 +327,18 @@ class TestSchedule(unittest.TestCase): out = x.to('cpu') check_schedule(out, 0, filter_loadops=False) - def test_pow_const_tensor(self): + def test_pow_const_tensor_simplified(self): x = Tensor([1,2,3,4]) + # NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5) out = x ** Tensor(2) check_schedule(out, 1) + def test_pow_const_tensor_to_zero(self): + x = Tensor([1,2,3,4]) + out = x ** Tensor(0) + # NOTE: this is ConstBuffer 0 + ConstBuffer 1 + check_schedule(out, 1) + def test_zero_size(self): x = Tensor.rand(2, 3, 0) out = x + 1 diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e64a19c753..43e41e79d9 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -763,12 +763,10 @@ class Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501 def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: x = self._to_float(x) - if x.__class__ is not Tensor and not reverse: + if not isinstance(x, Tensor) and not reverse: # simple pow identities if x < 0: return self.reciprocal().pow(-x) - if x == 3.0: return self*self*self - if x == 2.0: return self*self - if x == 1.0: return self + if x in [3,2,1,0]: return reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1) if x == 0.5: return self.sqrt() if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp() ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()