fix softmax and log_softmax for 0d tensor (#3463)

Matched torch's behavior: for a 0-d tensor, axis ∈ [-1, 0] is accepted, and axis=None is used internally for the reduction.
Author: chenyu
Date: 2024-02-21 11:30:30 -05:00
Committed by: GitHub
Parent: 871ba73e65
Commit: 1eb24af63b
2 changed files with 12 additions and 0 deletions

@@ -746,9 +746,18 @@ class TestOps(unittest.TestCase):
    forward_only = (Device.DEFAULT == "WEBGPU")
    helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
    helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
    helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
    helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
  def test_softmax_other_axis(self):
    helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7)
  def test_log_softmax(self):
    helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
    helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
    helper_test_op([()], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
    helper_test_op([()], torch.nn.LogSoftmax(dim=-1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
  def test_log_softmax_other_axis(self):
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)