Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-25 06:48:22 -05:00
Add test for .softmax.argmax (#3559)
* Add broken test for known issue
* skip PYTHON
* skip PYTHON
* fix commit

---------

Co-authored-by: schlimeszn <schlimeszn@gmail.com>
Co-authored-by: reddyn <nikidsniper@gmail.com>
@@ -752,7 +752,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7)
 
+  @unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "PYTHON"], "Broken ISSUE #3552")
+  def test_softmax_argmax(self):
+    helper_test_op([(45,65)], lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
+
   def test_log_softmax(self):
     helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
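For context, a minimal standalone sketch (not part of the commit) of the forward comparison the new test performs, assuming tinygrad and torch are installed. helper_test_op's tolerance and gradient machinery are omitted, and forward_only=True in the test means only this forward check applies:

    # Sketch of what test_softmax_argmax exercises: run the same
    # softmax().argmax() chain in tinygrad and in torch, then compare.
    import numpy as np
    import torch
    from tinygrad.tensor import Tensor

    x = np.random.randn(45, 65).astype(np.float32)

    for axis in (0, 1):
      tiny_out = Tensor(x).softmax(axis).argmax().numpy()
      torch_out = torch.from_numpy(x).softmax(axis).argmax().numpy()
      # The skipped backends (CLANG/PYTHON in CI) reportedly fail this
      # comparison; that is the breakage tracked as issue #3552.
      np.testing.assert_equal(tiny_out, torch_out)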