From 7eb035e7c5bfb76d7a09034a0d409e3872a1a0ed Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 7 May 2024 22:40:09 -0400 Subject: [PATCH] stronger test case for half mean overflow (#4470) --- test/test_dtype.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index e28a19727b..3027e6a755 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -638,10 +638,11 @@ class TestAutoCastType(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_mean_half_precision_overflow(self): - t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True) + N = 256 + t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N) np.testing.assert_allclose(t.mean().numpy(), 60000) t.square().mean().backward() - np.testing.assert_allclose(t.grad.numpy(), [60000 * 2 / 3] * 3) + np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N) class TestImplicitFunctionTypeChange(unittest.TestCase): def test_functions(self):