From b8d460d203cc39be28890ec6d302b8d4f7ea9050 Mon Sep 17 00:00:00 2001 From: Marcello Fuschi Date: Wed, 15 Nov 2023 20:38:39 +0100 Subject: [PATCH] Add Tensor.multinomial (#2295) * add Tensor.multinomial only with replacement * add support for 2D input in Tensor.multinomial * fix multinomial output shape * allow passing replacement=False to Tensor.multinomial when num_samples=1 * improve tests for Tensor.multinomial * fix edge case in Tensor.multinomial * Tensor.multinomial no more staticmethod --- test/test_randomness.py | 27 +++++++++++++++++++++++++++ tinygrad/tensor.py | 13 +++++++++++++ 2 files changed, 40 insertions(+) diff --git a/test/test_randomness.py b/test/test_randomness.py index cf1b0180a2..c36b03a234 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -96,6 +96,33 @@ class TestRandomness(unittest.TestCase): for shape in [(128, 64, 3, 3), (20, 24)]: self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape)) + def test_multinomial(self): + def _check_with_torch(p, num_samples, replacement): + tiny_res = Tensor(p).multinomial(num_samples, replacement=replacement) + torch_res = torch.tensor(p).multinomial(num_samples, replacement=replacement) + self.assertEqual(tiny_res.shape, torch_res.shape) + if torch_res.ndim == 1: + tiny_res = tiny_res.unsqueeze(0) + torch_res = torch_res.unsqueeze(0) + for i in range(torch_res.shape[0]): + self.assertTrue(equal_distribution(lambda *_: tiny_res[i], lambda _: torch_res[i])) + _check_with_torch(p=[0.231, 0., 1., 0.5], num_samples=2000, replacement=True) + _check_with_torch(p=[[0.2, 0.8]], num_samples=2000, replacement=True) # 2D but only 1 row + _check_with_torch(p=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=2000, replacement=True) + # no-replacement isn't supported, unless taking only one sample + p = [0.1, 0.9] + self.assertRaises(AssertionError, lambda: Tensor(p).multinomial(100, replacement=False)) + tiny_samples = [Tensor(p).multinomial(1, replacement=False).numpy().item() for _ in range(1000)] + torch_samples = [torch.tensor(p).multinomial(1, replacement=False).item() for _ in range(1000)] + self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples))) + + def test_multinomial_counterexample(self): + tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True) + torch_res = torch.tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True) + self.assertTrue(equal_distribution(lambda *_: tiny_res, lambda _: torch_res)) + torch_res = torch.tensor([0.2, 0.7, 0.1]).multinomial(2000, replacement=True) + self.assertFalse(equal_distribution(lambda *_: tiny_res, lambda _: torch_res)) + def test_conv2d_init(self): params = (128, 256, (3,3)) assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach()) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 10a280290b..e892ca97a9 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -216,6 +216,19 @@ class Tensor: std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:])) return Tensor.normal(*shape, mean=0.0, std=std, **kwargs) + def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor: + assert self.ndim <= 2, "p must be 1 or 2 dim" + assert replacement or num_samples == 1, "supported only with replacement" + p = self.unsqueeze(0) if self.ndim == 1 else self + cdf = p.cumsum(1) + cdf /= cdf[:, -1].unsqueeze(1) + unif_samples = Tensor.rand(num_samples, p.shape[0], 1) + indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2) + indices = indices.permute((1, 0)) + if self.ndim == 1: + indices = indices.squeeze(0) + return indices.cast(dtypes.int32) + # ***** toposort and backward pass ***** def deepwalk(self): def _deepwalk(node, visited, nodes):