onnx update for trilu and argmax (#3283)

* support 0 in shape for tril and triu

* select_last_index for ArgMax and ArgMin

* pass **kwargs
This commit is contained in:
chenyu
2024-01-30 18:39:16 -05:00
committed by GitHub
parent 5b46b0ff3d
commit 7816c3b692
4 changed files with 5 additions and 10 deletions

View File

@@ -81,8 +81,6 @@ backend_test.exclude('test_convinteger_*')
backend_test.exclude('test_matmulinteger_*')
# we don't support indexes
# backend_test.exclude('test_argmax_*') # Needs more work: select_last_index
# backend_test.exclude('test_argmin_*') # Needs more work: select_last_index
backend_test.exclude('test_nonzero_*')
# no support for mod
@@ -128,10 +126,6 @@ backend_test.exclude('test_bitwise_*')
backend_test.exclude('test_blackmanwindow_*')
backend_test.exclude('test_bernoulli_*')
backend_test.exclude('test_det_*')
backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support
backend_test.exclude('test_triu_zero_cpu') # TODO: zero array triu support
backend_test.exclude('test_col2im_*')
backend_test.exclude('test_hammingwindow_*')
backend_test.exclude('test_hannwindow_*')

View File

@@ -252,12 +252,14 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3)], lambda x: x.tril(1))
helper_test_op([(3,3)], lambda x: x.tril(-1))
helper_test_op([(5,3,3)], lambda x: x.tril())
helper_test_op([(5,0,3)], lambda x: x.tril())
helper_test_op([(5,3,3)], lambda x: x.tril(1))
def test_triu(self):
helper_test_op([(3,3)], lambda x: x.triu())
helper_test_op([(3,3)], lambda x: x.triu(1))
helper_test_op([(3,3)], lambda x: x.triu(-1))
helper_test_op([(5,3,3)], lambda x: x.triu())
helper_test_op([(5,0,3)], lambda x: x.triu())
helper_test_op([(5,3,3)], lambda x: x.triu(1))
def test_maximum(self):
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)