mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
don't round pow output for int pow int (#11625)
also added atol=0 and big pows for the tests
This commit is contained in:
@@ -700,10 +700,11 @@ class TestOps(unittest.TestCase):
|
|||||||
helper_test_op(None, lambda x: x**-1.0, vals=[[-1.0, 0.0, 1.0]])
|
helper_test_op(None, lambda x: x**-1.0, vals=[[-1.0, 0.0, 1.0]])
|
||||||
|
|
||||||
def test_int_pow_const_int(self):
|
def test_int_pow_const_int(self):
|
||||||
helper_test_op(None, lambda x: x**0, vals=[[-2,0,2]], forward_only=True)
|
helper_test_op(None, lambda x: x**0, vals=[[-2,0,2]], forward_only=True, atol=0)
|
||||||
helper_test_op(None, lambda x: x**1, vals=[[-2,0,2]], forward_only=True)
|
helper_test_op(None, lambda x: x**1, vals=[[-2,0,2]], forward_only=True, atol=0)
|
||||||
helper_test_op(None, lambda x: x**2, vals=[[-2,0,2]], forward_only=True)
|
helper_test_op(None, lambda x: x**2, vals=[[-2,0,2]], forward_only=True, atol=0)
|
||||||
helper_test_op(None, lambda x: x**29, vals=[[-2,0,2]], forward_only=True)
|
helper_test_op(None, lambda x: x**7, vals=[[11,12,13]], forward_only=True, atol=0)
|
||||||
|
helper_test_op(None, lambda x: x**29, vals=[[-2,0,2]], forward_only=True, atol=0)
|
||||||
self.helper_test_exception(None, lambda x: x**-2, vals=[[-2,0,2]], forward_only=True, expected=RuntimeError)
|
self.helper_test_exception(None, lambda x: x**-2, vals=[[-2,0,2]], forward_only=True, expected=RuntimeError)
|
||||||
|
|
||||||
@unittest.skip("not supported")
|
@unittest.skip("not supported")
|
||||||
|
|||||||
@@ -3684,7 +3684,7 @@ class Tensor(MathTrait):
|
|||||||
|
|
||||||
ret = base._apply_uop(UOp.pow, exponent)
|
ret = base._apply_uop(UOp.pow, exponent)
|
||||||
# NOTE: pow(int, float) -> int
|
# NOTE: pow(int, float) -> int
|
||||||
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) else ret
|
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) and dtypes.is_float(exponent.dtype) else ret
|
||||||
|
|
||||||
def maximum(self, x:Tensor|ConstType) -> Tensor:
|
def maximum(self, x:Tensor|ConstType) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user