support dtype in softmax and log_softmax (#6914)

Matches torch's behavior: for mixed precision training, we want to be able to compute softmax in float even when the input tensor is half.
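
A minimal usage sketch of the new keyword (the input values are illustrative; the output keeps the requested dtype, as the test below checks):

from tinygrad import Tensor, dtypes

x = Tensor([1, 2, 3], dtype=dtypes.half)
probs = x.softmax(0, dtype=dtypes.float)     # computed and returned in float
logp = x.log_softmax(0, dtype=dtypes.float)  # same for log_softmax
print(probs.dtype, logp.dtype)               # dtypes.float, dtypes.float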
chenyu
2024-10-06 07:18:15 -04:00
committed by GitHub
parent 718b959349
commit 75d9dcf000
2 changed files with 26 additions and 6 deletions

@@ -760,6 +760,25 @@ class TestAutoCastType(unittest.TestCase):
     t.square().mean().backward()
     np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
 
+  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  def test_softmax_dtype(self):
+    data = [1, 2, 3]
+    t = Tensor(data, dtype=dtypes.half)
+    tt = torch.tensor(data, dtype=torch.half)
+
+    out = t.softmax(0)
+    self.assertEqual(out.dtype, dtypes.half)
+    np.testing.assert_allclose(out.numpy(), tt.softmax(0).numpy(), rtol=1e-3)
+    out = t.softmax(0, dtype=dtypes.float)
+    self.assertEqual(out.dtype, dtypes.float)
+    np.testing.assert_allclose(out.numpy(), tt.softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
+
+    out = t.log_softmax(0)
+    self.assertEqual(out.dtype, dtypes.half)
+    np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3)
+    out = t.log_softmax(0, dtype=dtypes.float)
+    self.assertEqual(out.dtype, dtypes.float)
+    np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
+
 class TestImplicitFunctionTypeChange(unittest.TestCase):
   def test_functions(self):
     result = []
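
The hunk above covers only the test file; the tensor.py side of the change (the second changed file) is not shown. A minimal sketch of how the optional dtype could be threaded through, written as free functions over the public tinygrad Tensor API rather than as Tensor methods (a reconstruction under assumptions, not the exact commit code):

from tinygrad import Tensor

def softmax(t: Tensor, axis=-1, dtype=None) -> Tensor:
  # cast before the reductions so max/exp/sum all accumulate in the requested dtype
  x = t.cast(dtype) if dtype is not None else t
  m = x - x.max(axis=axis, keepdim=True)  # subtract the max for numerical stability
  e = m.exp()
  return e.div(e.sum(axis=axis, keepdim=True))

def log_softmax(t: Tensor, axis=-1, dtype=None) -> Tensor:
  x = t.cast(dtype) if dtype is not None else t
  m = x - x.max(axis=axis, keepdim=True)
  return m - m.exp().sum(axis=axis, keepdim=True).log()

Casting before the reductions is the point: running exp/sum in half and only casting the result afterwards would still lose precision in the accumulation, which is exactly what passing dtype is meant to avoid in mixed precision training.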