fix pow of int to negative const int (#8129)

it should return in int
This commit is contained in:
chenyu
2024-12-09 17:20:18 -05:00
committed by GitHub
parent 12f7d284e0
commit 358287959b
2 changed files with 6 additions and 1 deletions

View File

@@ -582,6 +582,10 @@ class TestOps(unittest.TestCase):
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7)
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7)
# pow to a const int
helper_test_op([], lambda: torch.tensor([2], dtype=torch.int) ** torch.tensor(-2, dtype=torch.int),
lambda: Tensor([2]) ** Tensor(-2), forward_only=True)
def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt())
helper_test_op([()], lambda x: x.sqrt())

View File

@@ -3196,8 +3196,9 @@ class Tensor(SimpleMathTrait):
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse:
# simple pow identities
if x < 0: return self.reciprocal().pow(-x)
if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
if x == 0: return 1 + self * 0
# rewrite pow 0.5 to sqrt
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)