diff --git a/test/test_dtype.py b/test/test_dtype.py index cbc26c05e2..6aaa5eaa18 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -1,4 +1,4 @@ -import unittest, operator, subprocess +import unittest, operator, subprocess, math import numpy as np import torch from typing import Any, List @@ -544,6 +544,17 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64 + @unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: support inf to half in PYTHON backend") + @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16") + def test_sum_acc_dtype(self): + t = Tensor([40000, 40000], dtype=dtypes.float16) + # default float16 sum returns in float16, overflowed in this case + assert t.sum().dtype == dtypes.float16 + assert math.isinf(t.sum().numpy().item()) + # specifiying acc_dtype and it's not downcasted + assert t.sum(acc_dtype=dtypes.float32).dtype == dtypes.float32 + np.testing.assert_allclose(t.sum(acc_dtype=dtypes.float32).numpy(), 80000) + def test_mean(self): assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32 diff --git a/test/test_nn.py b/test/test_nn.py index 6a8d461092..81a20c35a0 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -352,7 +352,7 @@ class TestNN(unittest.TestCase): z.sum().backward() torch_z.sum().backward(retain_graph=True) np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) - np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3) np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3) def test_embedding(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e73a68ec2a..9242779bd5 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1285,7 +1285,8 @@ class Tensor: ``` """ ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim) - return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret + return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret + def max(self, axis=None, keepdim=False): """ Returns the maximum value of the tensor along the specified axis or axes. @@ -1308,6 +1309,7 @@ class Tensor: ``` """ return self._reduce(F.Max, axis, keepdim) + def min(self, axis=None, keepdim=False): """ Returns the minimum value of the tensor along the specified axis or axes.