mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
support int Tensor pow to const non-negative int (#11624)
matches torch
This commit is contained in:
@@ -699,6 +699,13 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(None, lambda x: x**-0.3, vals=[[0.0]])
|
||||
helper_test_op(None, lambda x: x**-1.0, vals=[[-1.0, 0.0, 1.0]])
|
||||
|
||||
def test_int_pow_const_int(self):
|
||||
helper_test_op(None, lambda x: x**0, vals=[[-2,0,2]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x**1, vals=[[-2,0,2]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x**2, vals=[[-2,0,2]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x**29, vals=[[-2,0,2]], forward_only=True)
|
||||
self.helper_test_exception(None, lambda x: x**-2, vals=[[-2,0,2]], forward_only=True, expected=RuntimeError)
|
||||
|
||||
@unittest.skip("not supported")
|
||||
def test_pow_int(self):
|
||||
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)
|
||||
|
||||
Reference in New Issue
Block a user