_reduce_op is axis based now (#3462)

* _reduce_op is axis based now

* axis_

* update lin failures

* disable that

* fix shape
This commit is contained in:
George Hotz
2024-02-21 16:36:31 +01:00
committed by GitHub
parent 22a90cbb15
commit 871ba73e65
8 changed files with 38 additions and 29 deletions

View File

@@ -745,10 +745,10 @@ class TestOps(unittest.TestCase):
# exceed per kernel buffer limit with backward
forward_only = (Device.DEFAULT == "WEBGPU")
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
def test_log_softmax(self):
helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
def test_log_softmax_other_axis(self):
helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)