mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
onnx update for trilu and argmax (#3283)
* support 0 in shape for tril and triu * select_last_index for ArgMax and ArgMin * pass **kwargs
This commit is contained in:
6
test/external/external_test_onnx_backend.py
vendored
6
test/external/external_test_onnx_backend.py
vendored
@@ -81,8 +81,6 @@ backend_test.exclude('test_convinteger_*')
|
||||
backend_test.exclude('test_matmulinteger_*')
|
||||
|
||||
# we don't support indexes
|
||||
# backend_test.exclude('test_argmax_*') # Needs more work: select_last_index
|
||||
# backend_test.exclude('test_argmin_*') # Needs more work: select_last_index
|
||||
backend_test.exclude('test_nonzero_*')
|
||||
|
||||
# no support for mod
|
||||
@@ -128,10 +126,6 @@ backend_test.exclude('test_bitwise_*')
|
||||
backend_test.exclude('test_blackmanwindow_*')
|
||||
backend_test.exclude('test_bernoulli_*')
|
||||
backend_test.exclude('test_det_*')
|
||||
|
||||
backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support
|
||||
backend_test.exclude('test_triu_zero_cpu') # TODO: zero array triu support
|
||||
|
||||
backend_test.exclude('test_col2im_*')
|
||||
backend_test.exclude('test_hammingwindow_*')
|
||||
backend_test.exclude('test_hannwindow_*')
|
||||
|
||||
@@ -252,12 +252,14 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,3)], lambda x: x.tril(1))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(-1))
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,0,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril(1))
|
||||
def test_triu(self):
|
||||
helper_test_op([(3,3)], lambda x: x.triu())
|
||||
helper_test_op([(3,3)], lambda x: x.triu(1))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(-1))
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,0,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu(1))
|
||||
def test_maximum(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
|
||||
|
||||
Reference in New Issue
Block a user