mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
check Tensor.permute input arg is a valid permutation (#5069)
also added support of negative axes
This commit is contained in:
@@ -1112,9 +1112,15 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,3)], lambda x: x.T)
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2))
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(0,2))
|
||||
|
||||
def test_permute(self):
|
||||
helper_test_op([(1,2,3,4)], lambda x: x.permute((3,0,2,1)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.permute((3,2,1,0)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.permute((-2,-1,1,0)))
|
||||
helper_test_op([()], lambda x: x.permute(()))
|
||||
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,2)), lambda x: x.permute((0,2)), expected=RuntimeError)
|
||||
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,1,2,3,3,3)), lambda x: x.permute((0,1,2,3,3,3)), expected=RuntimeError)
|
||||
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError)
|
||||
|
||||
def test_reshape(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,3,6,6)))
|
||||
|
||||
Reference in New Issue
Block a user