mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
_reduce_op is axis based now (#3462)
* _reduce_op is axis based now * axis_ * update lin failures * disable that * fix shape
This commit is contained in:
@@ -745,10 +745,10 @@ class TestOps(unittest.TestCase):
|
||||
# exceed per kernel buffer limit with backward
|
||||
forward_only = (Device.DEFAULT == "WEBGPU")
|
||||
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
|
||||
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
|
||||
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
|
||||
def test_log_softmax(self):
|
||||
helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
def test_log_softmax_other_axis(self):
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
|
||||
|
||||
Reference in New Issue
Block a user