mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
lower sample count in test_multinomial (#12027)
This commit is contained in:
@@ -323,9 +323,9 @@ class TestRandomness(unittest.TestCase):
|
||||
torch_res = torch_res.unsqueeze(0)
|
||||
for i in range(torch_res.shape[0]):
|
||||
self.assertTrue(equal_distribution(lambda *_: tiny_res[i], lambda _: torch_res[i]))
|
||||
_check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=2000, replacement=True)
|
||||
_check_with_torch(w=[[0.2, 0.8]], num_samples=2000, replacement=True) # 2D but only 1 row
|
||||
_check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=2000, replacement=True)
|
||||
_check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=300, replacement=True)
|
||||
_check_with_torch(w=[[0.2, 0.8]], num_samples=300, replacement=True) # 2D but only 1 row
|
||||
_check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=300, replacement=True)
|
||||
# no-replacement isn't supported, unless taking only one sample
|
||||
w = [0.1, 0.9]
|
||||
self.assertRaises(AssertionError, lambda: Tensor(w).multinomial(100, replacement=False))
|
||||
|
||||
Reference in New Issue
Block a user