fix gradient of pow(t, int) (#9217)

Semi-reverted some pow logic back into tensor.py. Added a direct gradient check because the backward check in test_ops only passed by luck.
Author: chenyu
Date: 2025-02-23 17:42:09 -05:00
Committed by: GitHub
Parent: 12b5b83821
Commit: b3ae664d5d
2 changed files with 23 additions and 0 deletions
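What the fix addresses, as a minimal sketch (not part of the commit) using the same `Tensor.gradient` API the new test below exercises: d/dx x**2 = 2x, which is 0 at x = 0, is exactly the kind of value the old backward got right only by luck.

```python
from tinygrad import Tensor, dtypes

t = Tensor([0.0], dtype=dtypes.float)
g = (t ** 2)[0].gradient(t)[0].item()
print(g)  # expect 0.0 (= 2*0); the case the new direct check pins down
```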

@@ -629,6 +629,23 @@ class TestOps(unittest.TestCase):
    helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]])
    helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]])

  def test_pow_const_direct(self):
    # x ** c
    def get_tiny_gradient(x, c):
      t = Tensor([x], dtype=dtypes.float)
      return (t ** c)[0].gradient(t)[0].item()
    def get_torch_gradient(x, c):
      t = torch.tensor([x], dtype=torch.float, requires_grad=True)
      return torch.autograd.grad(t ** c, t)[0].item()
    for x in [0, 1]:
      for c in [-1, 0, 0.3, 1, 2]:
        tiny_out = get_tiny_gradient(x, c)
        torch_out = get_torch_gradient(x, c)
        if math.isnan(tiny_out):
          assert math.isnan(torch_out)
        else:
          self.assertAlmostEqual(tiny_out, torch_out, msg=f"{x}, {c}")

  def test_pow_zero_tensor(self):
    helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.3]])
    helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.0]])
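For reference, the values under comparison reduce to the closed form d/dx x**c = c * x**(c-1); the isnan branch exists because several (x, c) pairs at x = 0 have no finite derivative, and the test defers to torch for those. A standalone sketch of that closed form (plain Python, not part of the diff):

```python
import math

def analytic_grad(x: float, c: float) -> float:
  # d/dx x**c = c * x**(c-1); Python raises ZeroDivisionError for a zero base
  # with a negative exponent, which maps onto the test's NaN branch
  try: return c * x ** (c - 1)
  except ZeroDivisionError: return math.nan

for x in [0, 1]:
  for c in [-1, 0, 0.3, 1, 2]:
    print(f"x={x} c={c} grad={analytic_grad(x, c)}")
```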

@@ -3373,6 +3373,12 @@ class Tensor(SimpleMathTrait):
    print((2.0 ** Tensor([-1, 2, 3])).numpy())
    ```
    """
    # TODO: combine this with gradient
    if not reverse and isinstance(x, get_args(ConstType)) and int(x) == x:
      if x < 0: return self.reciprocal().pow(-x)
      if x == 0: return self*0+1
      return self.pow(int(x)//2).square() * (self if x%2 == 1 else 1)
    base, exponent = self._broadcasted(x, reverse=reverse)
    # TODO: int pow
    if not base.is_floating_point(): raise RuntimeError("base needs to be float")
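The added branch is exponentiation by squaring built from ops with well-defined backward rules (reciprocal, square, mul), instead of the generic float path that follows it. A hypothetical standalone mirror of the recursion, using plain floats to show the structure:

```python
def int_pow(base: float, c: int) -> float:
  # mirrors the Tensor fast path above: negative exponents go through the
  # reciprocal, exponent 0 returns 1, and the rest is square-and-multiply
  if c < 0: return 1.0 / int_pow(base, -c)
  if c == 0: return 1.0
  half_sq = int_pow(base, c // 2) ** 2
  return half_sq * base if c % 2 == 1 else half_sq

assert int_pow(3.0, 5) == 243.0   # (3**2)**2 * 3
assert int_pow(2.0, -2) == 0.25
```

Because every step is a square or a multiply, autodiff through this chain produces the polynomial derivative directly (e.g. 2x for x**2, so 0 at x = 0), rather than routing a constant integer exponent through the float path below, which presumably takes a log of the base and is where the gradient at 0 can degenerate into NaN.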