don't round pow output for int pow int (#11625)

also added atol=0 and big pows for the tests
This commit is contained in:
chenyu
2025-08-11 17:57:47 -07:00
committed by GitHub
parent d623f6d850
commit 0c97d6de1b
2 changed files with 6 additions and 5 deletions

View File

@@ -700,10 +700,11 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x**-1.0, vals=[[-1.0, 0.0, 1.0]])
def test_int_pow_const_int(self):
helper_test_op(None, lambda x: x**0, vals=[[-2,0,2]], forward_only=True)
helper_test_op(None, lambda x: x**1, vals=[[-2,0,2]], forward_only=True)
helper_test_op(None, lambda x: x**2, vals=[[-2,0,2]], forward_only=True)
helper_test_op(None, lambda x: x**29, vals=[[-2,0,2]], forward_only=True)
helper_test_op(None, lambda x: x**0, vals=[[-2,0,2]], forward_only=True, atol=0)
helper_test_op(None, lambda x: x**1, vals=[[-2,0,2]], forward_only=True, atol=0)
helper_test_op(None, lambda x: x**2, vals=[[-2,0,2]], forward_only=True, atol=0)
helper_test_op(None, lambda x: x**7, vals=[[11,12,13]], forward_only=True, atol=0)
helper_test_op(None, lambda x: x**29, vals=[[-2,0,2]], forward_only=True, atol=0)
self.helper_test_exception(None, lambda x: x**-2, vals=[[-2,0,2]], forward_only=True, expected=RuntimeError)
@unittest.skip("not supported")

View File

@@ -3684,7 +3684,7 @@ class Tensor(MathTrait):
ret = base._apply_uop(UOp.pow, exponent)
# NOTE: pow(int, float) -> int
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) else ret
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) and dtypes.is_float(exponent.dtype) else ret
def maximum(self, x:Tensor|ConstType) -> Tensor:
"""