diff --git a/test/test_ops.py b/test/test_ops.py index df2feb9512..01dc38cf9b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -699,6 +699,13 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x: x**-0.3, vals=[[0.0]]) helper_test_op(None, lambda x: x**-1.0, vals=[[-1.0, 0.0, 1.0]]) + def test_int_pow_const_int(self): + helper_test_op(None, lambda x: x**0, vals=[[-2,0,2]], forward_only=True) + helper_test_op(None, lambda x: x**1, vals=[[-2,0,2]], forward_only=True) + helper_test_op(None, lambda x: x**2, vals=[[-2,0,2]], forward_only=True) + helper_test_op(None, lambda x: x**29, vals=[[-2,0,2]], forward_only=True) + self.helper_test_exception(None, lambda x: x**-2, vals=[[-2,0,2]], forward_only=True, expected=RuntimeError) + @unittest.skip("not supported") def test_pow_int(self): def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7854434a4c..0734f07489 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3680,7 +3680,7 @@ class Tensor(MathTrait): """ base, exponent = self._broadcasted(x, reverse=reverse) # TODO: int pow - if not base.is_floating_point(): raise RuntimeError("base needs to be float") + if not base.is_floating_point() and not (isinstance(x, int) and x >= 0): raise RuntimeError("base needs to be float") ret = base._apply_uop(UOp.pow, exponent) # NOTE: pow(int, float) -> int