mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -582,6 +582,10 @@ class TestOps(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7)
|
||||
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7)
|
||||
|
||||
# pow to a const int
|
||||
helper_test_op([], lambda: torch.tensor([2], dtype=torch.int) ** torch.tensor(-2, dtype=torch.int),
|
||||
lambda: Tensor([2]) ** Tensor(-2), forward_only=True)
|
||||
|
||||
def test_sqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt())
|
||||
helper_test_op([()], lambda x: x.sqrt())
|
||||
|
||||
@@ -3196,8 +3196,9 @@ class Tensor(SimpleMathTrait):
|
||||
x = self._to_const_val(x)
|
||||
if not isinstance(x, Tensor) and not reverse:
|
||||
# simple pow identities
|
||||
if x < 0: return self.reciprocal().pow(-x)
|
||||
if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
|
||||
if x == 0: return 1 + self * 0
|
||||
# rewrite pow 0.5 to sqrt
|
||||
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
|
||||
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user