mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
don't round pow output for int pow int (#11625)
also added atol=0 and big pows for the tests
This commit is contained in:
@@ -700,10 +700,11 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(None, lambda x: x**-1.0, vals=[[-1.0, 0.0, 1.0]])
|
||||
|
||||
def test_int_pow_const_int(self):
|
||||
helper_test_op(None, lambda x: x**0, vals=[[-2,0,2]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x**1, vals=[[-2,0,2]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x**2, vals=[[-2,0,2]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x**29, vals=[[-2,0,2]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x**0, vals=[[-2,0,2]], forward_only=True, atol=0)
|
||||
helper_test_op(None, lambda x: x**1, vals=[[-2,0,2]], forward_only=True, atol=0)
|
||||
helper_test_op(None, lambda x: x**2, vals=[[-2,0,2]], forward_only=True, atol=0)
|
||||
helper_test_op(None, lambda x: x**7, vals=[[11,12,13]], forward_only=True, atol=0)
|
||||
helper_test_op(None, lambda x: x**29, vals=[[-2,0,2]], forward_only=True, atol=0)
|
||||
self.helper_test_exception(None, lambda x: x**-2, vals=[[-2,0,2]], forward_only=True, expected=RuntimeError)
|
||||
|
||||
@unittest.skip("not supported")
|
||||
|
||||
@@ -3684,7 +3684,7 @@ class Tensor(MathTrait):
|
||||
|
||||
ret = base._apply_uop(UOp.pow, exponent)
|
||||
# NOTE: pow(int, float) -> int
|
||||
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) else ret
|
||||
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) and dtypes.is_float(exponent.dtype) else ret
|
||||
|
||||
def maximum(self, x:Tensor|ConstType) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user