mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
fix gradient of pow(t, int) (#9217)
Semi-revert some pow logic back to tensor. Added a direct gradient check because the backward check in test_ops only passed by luck.
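As context for the change, here is a minimal standalone sketch of the property the new test asserts: for an integer constant exponent c, the gradient of `t ** c` should match the analytical derivative `c * t ** (c - 1)`. It mirrors the `get_tiny_gradient` pattern from the test in the diff below; the helper name `pow_gradient` and the sampled points are illustrative, not part of the commit.

```python
# Minimal sketch of the property the new test checks (not part of the diff):
# for an integer exponent c, d/dt (t ** c) should equal c * t ** (c - 1).
# Assumes a recent tinygrad where Tensor.gradient(*targets) is available,
# mirroring the get_tiny_gradient pattern used in the test below.
from tinygrad import Tensor, dtypes

def pow_gradient(x: float, c: int) -> float:  # illustrative helper name
  t = Tensor([x], dtype=dtypes.float)
  return (t ** c)[0].gradient(t)[0].item()

for x, c in [(1.0, 2), (2.0, 3), (0.0, 2)]:
  analytic = c * x ** (c - 1)  # polynomial derivative
  print(f"x={x} c={c}: tinygrad={pow_gradient(x, c)} analytic={analytic}")
```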
@@ -629,6 +629,23 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]])
     helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]])
 
+  def test_pow_const_direct(self):
+    # x ** c
+    def get_tiny_gradient(x, c):
+      t = Tensor([x], dtype=dtypes.float)
+      return (t ** c)[0].gradient(t)[0].item()
+    def get_torch_gradient(x, c):
+      t = torch.tensor([x], dtype=torch.float, requires_grad=True)
+      return torch.autograd.grad(t ** c, t)[0].item()
+    for x in [0, 1]:
+      for c in [-1, 0, 0.3, 1, 2]:
+        tiny_out = get_tiny_gradient(x, c)
+        torch_out = get_torch_gradient(x, c)
+        if math.isnan(tiny_out):
+          assert math.isnan(torch_out)
+        else:
+          self.assertAlmostEqual(tiny_out, torch_out, msg=f"{x}, {c}")
+
   def test_pow_zero_tensor(self):
     helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.3]])
     helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.0]])
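The `math.isnan` branch in the test above exists because, at x = 0, the analytical derivative c·x^(c-1) is undefined for several of the sampled exponents, so the two frameworks can only be required to agree on NaN-ness rather than on a finite value. A small plain-Python sketch (the helper name `analytic_grad` is illustrative) of which (x, c) pairs from the test grid have a finite naive derivative:

```python
import math

# Naive analytical derivative of x**c; at x == 0 the formula breaks down
# for exponents below 1, which is why the test only compares NaN-ness there.
def analytic_grad(x: float, c: float) -> float:  # illustrative helper
  try:
    return c * x ** (c - 1)
  except ZeroDivisionError:  # 0.0 ** negative raises in Python
    return math.nan

for x in [0.0, 1.0]:
  for c in [-1, 0, 0.3, 1, 2]:
    print(f"x={x} c={c}: {analytic_grad(x, c)}")
```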
@@ -3373,6 +3373,12 @@ class Tensor(SimpleMathTrait):
     print((2.0 ** Tensor([-1, 2, 3])).numpy())
     ```
     """
+    # TODO: combine this with gradient
+    if not reverse and isinstance(x, get_args(ConstType)) and int(x) == x:
+      if x < 0: return self.reciprocal().pow(-x)
+      if x == 0: return self*0+1
+      return self.pow(int(x)//2).square() * (self if x%2 == 1 else 1)
+
     base, exponent = self._broadcasted(x, reverse=reverse)
     # TODO: int pow
     if not base.is_floating_point(): raise RuntimeError("base needs to be float")
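The branch added to `pow` above handles a whole-number constant exponent without going through the generic float path: negative exponents are reduced to reciprocals, a zero exponent maps to ones, and positive exponents are built by exponentiation by squaring, so autodiff only sees plain multiplications (presumably the point of the fix, since a log/exp formulation has an ill-defined gradient at zero and negative bases). A standalone plain-Python sketch of the same recursion, with the illustrative name `ipow`:

```python
# Plain-Python sketch of the recursion used in the new integer-exponent branch:
#   x ** n == (x ** (n // 2)) ** 2 * (x if n is odd else 1)
# mirroring self.reciprocal().pow(-x), self*0+1, and
# self.pow(int(x)//2).square() * (self if x%2 == 1 else 1) above.
def ipow(base: float, n: int) -> float:  # illustrative helper name
  if n < 0: return ipow(1.0 / base, -n)  # negative exponent -> reciprocal
  if n == 0: return 1.0                  # x ** 0 == 1
  half = ipow(base, n // 2)
  return half * half * (base if n % 2 == 1 else 1.0)

assert ipow(3.0, 4) == 81.0
assert ipow(2.0, -3) == 0.125
assert ipow(5.0, 0) == 1.0
```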