mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
fix mean underflow for half tensor (#4377)
* fix mean underflow for half tensor: divide only by the reduce factor. Added a unit test and a non-NaN assertion in ResNet training; also added a failing test case for symbolic shape var * skip for python backend
This commit is contained in:
@@ -176,6 +176,7 @@ def train_resnet():
|
|||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
if i == BENCHMARK:
|
if i == BENCHMARK:
|
||||||
|
assert not math.isnan(loss)
|
||||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||||
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
|
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
|
||||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||||
|
|||||||
@@ -621,9 +621,17 @@ class TestAutoCastType(unittest.TestCase):
|
|||||||
t.reshape(2, 1).expand(2, 10001).max().backward()
|
t.reshape(2, 1).expand(2, 10001).max().backward()
|
||||||
np.testing.assert_allclose(t.grad.numpy(), [1, 0])
|
np.testing.assert_allclose(t.grad.numpy(), [1, 0])
|
||||||
|
|
||||||
|
@unittest.skipIf(Device.DEFAULT=="PYTHON", "very slow")
|
||||||
|
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||||
|
def test_mean_half_precision_underflow(self):
|
||||||
|
N = 10000
|
||||||
|
x = 0.001
|
||||||
|
t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
|
||||||
|
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
|
||||||
|
|
||||||
@unittest.skip("TODO: fix this")
|
@unittest.skip("TODO: fix this")
|
||||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||||
def test_mean_half_precision(self):
|
def test_mean_half_precision_overflow(self):
|
||||||
t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
|
t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
|
||||||
np.testing.assert_allclose(t.mean().numpy(), 60000)
|
np.testing.assert_allclose(t.mean().numpy(), 60000)
|
||||||
t.square().mean().backward()
|
t.square().mean().backward()
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ class TestTensorVariable(unittest.TestCase):
|
|||||||
ret = t.mean().item()
|
ret = t.mean().item()
|
||||||
assert ret == 1
|
assert ret == 1
|
||||||
|
|
||||||
|
@unittest.skip("symbolic var isn't supported")
|
||||||
|
def test_symbolic_var(self):
|
||||||
|
vv = Variable("a", 1, 10)
|
||||||
|
vv.bind(2)
|
||||||
|
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
|
||||||
|
ret = t.var().item()
|
||||||
|
assert ret == 0
|
||||||
|
|
||||||
def test_symbolic_mean_2d(self):
|
def test_symbolic_mean_2d(self):
|
||||||
vv = Variable("a", 1, 10)
|
vv = Variable("a", 1, 10)
|
||||||
vv.bind(2)
|
vv.bind(2)
|
||||||
|
|||||||
@@ -926,7 +926,7 @@ class Tensor:
|
|||||||
|
|
||||||
def mean(self, axis=None, keepdim=False):
|
def mean(self, axis=None, keepdim=False):
|
||||||
out = self.sum(axis=axis, keepdim=keepdim)
|
out = self.sum(axis=axis, keepdim=keepdim)
|
||||||
return out.div(prod(self.shape)).mul(prod(out.shape)) if 0 not in out.shape else out
|
return out.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so]))
|
||||||
def var(self, axis=None, keepdim=False, correction=1):
|
def var(self, axis=None, keepdim=False, correction=1):
|
||||||
assert all_int(self.shape), "does not support symbolic shape"
|
assert all_int(self.shape), "does not support symbolic shape"
|
||||||
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
|
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
|
||||||
|
|||||||
Reference in New Issue
Block a user