Tensor.sum returns in acc_dtype if specified (#5012)

* Tensor.sum returns in acc_dtype if specified

* skip PYTHON for now

* revert that

* relax that
This commit is contained in:
chenyu
2024-06-17 16:35:52 -04:00
committed by GitHub
parent 013c73c3b3
commit 4296507021
3 changed files with 16 additions and 3 deletions

View File

@@ -1,4 +1,4 @@
import unittest, operator, subprocess
import unittest, operator, subprocess, math
import numpy as np
import torch
from typing import Any, List
@@ -544,6 +544,17 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
@unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: support inf to half in PYTHON backend")
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
def test_sum_acc_dtype(self):
  """Check the dtype and value of a float16 sum, with and without an explicit acc_dtype."""
  # 40000 + 40000 overflows float16, which makes the two cases below distinguishable
  half = Tensor([40000, 40000], dtype=dtypes.float16)

  # no acc_dtype: the result comes back as float16 and has overflowed to inf
  default_dtype = half.sum().dtype
  assert default_dtype == dtypes.float16
  default_value = half.sum().numpy().item()
  assert math.isinf(default_value)

  # explicit acc_dtype: the result stays in acc_dtype instead of being cast back to float16
  wide_dtype = half.sum(acc_dtype=dtypes.float32).dtype
  assert wide_dtype == dtypes.float32
  wide_value = half.sum(acc_dtype=dtypes.float32).numpy()
  np.testing.assert_allclose(wide_value, 80000)
def test_mean(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32

View File

@@ -352,7 +352,7 @@ class TestNN(unittest.TestCase):
z.sum().backward()
torch_z.sum().backward(retain_graph=True)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
def test_embedding(self):

View File

@@ -1285,7 +1285,8 @@ class Tensor:
```
"""
ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
def max(self, axis=None, keepdim=False):
"""
Returns the maximum value of the tensor along the specified axis or axes.
@@ -1308,6 +1309,7 @@ class Tensor:
```
"""
return self._reduce(F.Max, axis, keepdim)
def min(self, axis=None, keepdim=False):
"""
Returns the minimum value of the tensor along the specified axis or axes.