fix: make Tensor.rand produce correct values for float16 (#3654)

* fix: make Tensor.rand produce correct values for float16

Due to precision loss when casting to float16, the data distribution created by custom_random isn't correctly in the interval ]0, 1[, but instead in the interval ]0, 1], which causes Tensor.randn to incorrectly generate infinite values.

The solution uses a scaling value to make sure the values stay under 1 when using half precision.

Closes #3611

* update implementation to truncate to the closest f16 value below 1

* chore: fix whitespace

* test larger distribution

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Skosh
2024-03-11 00:48:00 +02:00
committed by GitHub
parent bad6adaf8c
commit e8c350fdac
2 changed files with 11 additions and 1 deletions

View File

@@ -58,6 +58,15 @@ class TestRandomness(unittest.TestCase):
self.assertFalse(normal_test(Tensor.rand))
self.assertTrue(equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x)))
def test_rand_half(self):
    """rand with dtype=half must cover [0, 1) without ever emitting exactly 1.0.

    A value of exactly 1.0 would make randn's log-based transform produce
    infinities (see tinygrad issue #3611); zeros, by contrast, are expected.
    """
    side = 128
    sample = Tensor.rand((2, side, side), dtype=dtypes.half).realize().numpy()
    # no element may round up to 1.0 after the half-precision cast
    self.assertFalse((sample == 1).any())
    # the distribution still includes 0, so at least one zero should appear
    self.assertTrue((sample == 0).any())
    # and the overall distribution should match the reference samplers
    equal_distribution(lambda *shape: Tensor.rand(*shape, dtype=dtypes.float16),
                       torch.rand,
                       lambda shape: np.random.rand(*shape),
                       shape=(2, side, side))
def test_randn(self):
    """randn should look Gaussian and match the torch/numpy reference samplers."""
    self.assertTrue(normal_test(Tensor.randn))
    numpy_reference = lambda shape: np.random.randn(*shape)
    self.assertTrue(equal_distribution(Tensor.randn, torch.randn, numpy_reference))

View File

@@ -1014,5 +1014,6 @@ def custom_random(out:Buffer):
Tensor._seed += 1
if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
rng = np.random.default_rng(Tensor._seed)
# half: draw integers in [0, 2047) and divide by 2048 so every value casts to a
# float16 strictly below 1.0 — a float32 uniform close to 1 would round up to
# exactly 1.0 in half precision, and randn's log(1 - x) would yield infinity.
if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
out.copyin(rng_np_buffer.data)