support dtype in softmax and log_softmax (#6914)
Matches torch. For mixed precision training, we want to use float for softmax.
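As a usage sketch (not part of the diff below, and assuming a tinygrad build that includes this change), the new dtype= argument lets half activations run their softmax in float:

from tinygrad import Tensor, dtypes

# activations kept in half for memory savings
logits = Tensor([1.0, 2.0, 3.0], dtype=dtypes.half)

# default: softmax is computed and returned in the input dtype (half)
probs = logits.softmax(0)
assert probs.dtype == dtypes.half

# mixed precision: pass dtype=dtypes.float so exp and the normalizing
# sum run in float; the result dtype follows, as the tests below assert
probs_f = logits.softmax(0, dtype=dtypes.float)
assert probs_f.dtype == dtypes.float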
@@ -760,6 +760,25 @@ class TestAutoCastType(unittest.TestCase):
     t.square().mean().backward()
     np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
 
+  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  def test_softmax_dtype(self):
+    data = [1, 2, 3]
+    t = Tensor(data, dtype=dtypes.half)
+    tt = torch.tensor(data, dtype=torch.half)
+
+    out = t.softmax(0)
+    self.assertEqual(out.dtype, dtypes.half)
+    np.testing.assert_allclose(out.numpy(), tt.softmax(0).numpy(), rtol=1e-3)
+    out = t.softmax(0, dtype=dtypes.float)
+    self.assertEqual(out.dtype, dtypes.float)
+    np.testing.assert_allclose(out.numpy(), tt.softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
+    out = t.log_softmax(0)
+    self.assertEqual(out.dtype, dtypes.half)
+    np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3)
+    out = t.log_softmax(0, dtype=dtypes.float)
+    self.assertEqual(out.dtype, dtypes.float)
+    np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
+
 class TestImplicitFunctionTypeChange(unittest.TestCase):
   def test_functions(self):
     result = []