threefry half (#6154)

This commit is contained in:
wozeparrot
2024-08-18 15:23:12 -07:00
committed by GitHub
parent fad1818530
commit 0c5189de25
5 changed files with 39 additions and 16 deletions

View File

@@ -64,16 +64,16 @@ class TestRandomness(unittest.TestCase):
self.assertFalse(normal_test(Tensor.rand))
self.assertTrue(equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x)))
@unittest.skipIf(THREEFRY.value, "broken with threefry")
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need bfloat16 support")
def test_rand_half(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.half)
assert x.dtype == dtypes.half
x = x.numpy()
ones = np.take(x, np.where(x == 1))
zeros = np.take(x, np.where(x == 0))
self.assertTrue(ones.size == 0)
self.assertTrue(zeros.size > 0)
ones = x[x == 1]
zeros = x[x == 0]
assert ones.size == 0
assert zeros.size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
@unittest.skipIf(not THREEFRY.value, "not using threefry")
@@ -149,11 +149,11 @@ class TestRandomness(unittest.TestCase):
lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
def test_kaiming_uniform(self):
for shape in [(128, 64, 3, 3), (20, 24)]:
for shape in [(128, 64, 3, 3), (20, 24), (3, 55, 5)]:
self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
def test_kaiming_normal(self):
for shape in [(128, 64, 3, 3), (20, 24)]:
for shape in [(128, 64, 3, 3), (20, 24), (3, 55, 5)]:
self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))
def test_multinomial(self):