Mirror of https://github.com/tinygrad/tinygrad.git
Tensor.sum returns in acc_dtype if specified (#5012)
* Tensor.sum returns in acc_dtype if specified
* skip PYTHON for now
* revert that
* relax that
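For reference, a short usage sketch of the behavior this commit describes, written against the tinygrad Python API as exercised by the tests below (the import path and values are assumptions taken from those tests, not part of the diff):

    from tinygrad import Tensor, dtypes

    t = Tensor([40000, 40000], dtype=dtypes.float16)
    # without acc_dtype, the result is cast back to the storage dtype (float16 here)
    print(t.sum().dtype)
    # with acc_dtype specified, the result now stays in that dtype
    s = t.sum(acc_dtype=dtypes.float32)
    print(s.dtype, s.numpy())  # dtypes.float32, 80000.0 -- no longer downcast to float16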
@@ -1,4 +1,4 @@
-import unittest, operator, subprocess
+import unittest, operator, subprocess, math
 import numpy as np
 import torch
 from typing import Any, List
@@ -544,6 +544,17 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
 
+  @unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: support inf to half in PYTHON backend")
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
+  def test_sum_acc_dtype(self):
+    t = Tensor([40000, 40000], dtype=dtypes.float16)
+    # default float16 sum returns in float16, overflowed in this case
+    assert t.sum().dtype == dtypes.float16
+    assert math.isinf(t.sum().numpy().item())
+    # specifying acc_dtype, so the result is not downcast
+    assert t.sum(acc_dtype=dtypes.float32).dtype == dtypes.float32
+    np.testing.assert_allclose(t.sum(acc_dtype=dtypes.float32).numpy(), 80000)
+
   def test_mean(self):
     assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
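The magic numbers in test_sum_acc_dtype come from float16's range: the largest finite float16 value is 65504, so accumulating 40000 + 40000 in float16 overflows to inf, while a float32 accumulator represents 80000 exactly. A quick NumPy check of that arithmetic (an illustration, not part of the diff):

    import numpy as np

    print(np.finfo(np.float16).max)               # 65504.0, the largest finite float16
    print(np.float16(40000) + np.float16(40000))  # inf: the float16 sum overflows
    print(np.float32(40000) + np.float32(40000))  # 80000.0: fits comfortably in float32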
@@ -352,7 +352,7 @@ class TestNN(unittest.TestCase):
     z.sum().backward()
     torch_z.sum().backward(retain_graph=True)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
-    np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
+    np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
     np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
 
   def test_embedding(self):
@@ -1285,7 +1285,8 @@ class Tensor:
     ```
     """
     ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
-    return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret
+    return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
 
   def max(self, axis=None, keepdim=False):
     """
     Returns the maximum value of the tensor along the specified axis or axes.
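To restate the Tensor.sum change above in isolation: accumulation still happens in acc_dtype (or in a wider default chosen by sum_acc_dtype), but the result is cast back down to a half-precision storage dtype only when the caller did not ask for an acc_dtype. A standalone sketch of that decision, using a hypothetical stand-in for sum_acc_dtype (the real helper lives in tinygrad and may pick different defaults):

    from tinygrad import dtypes

    def default_acc_dtype(dt):
      # hypothetical stand-in: assume half-precision floats accumulate in float32
      # and every other dtype accumulates in itself
      return dtypes.float32 if dt in (dtypes.float16, dtypes.bfloat16) else dt

    def sum_result_dtype(input_dtype, acc_dtype=None):
      # mirrors the new return expression: cast back to the half-precision storage
      # dtype only when no acc_dtype was requested; otherwise keep the accumulator dtype
      if acc_dtype is None and input_dtype in (dtypes.float16, dtypes.bfloat16):
        return input_dtype
      return acc_dtype or default_acc_dtype(input_dtype)

    assert sum_result_dtype(dtypes.float16) == dtypes.float16
    assert sum_result_dtype(dtypes.float16, acc_dtype=dtypes.float32) == dtypes.float32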
@@ -1308,6 +1309,7 @@ class Tensor:
     ```
     """
     return self._reduce(F.Max, axis, keepdim)
 
   def min(self, axis=None, keepdim=False):
     """
     Returns the minimum value of the tensor along the specified axis or axes.