diff --git a/docs/tensor/creation.md b/docs/tensor/creation.md index d58a1ea52f..dbc5d7739c 100644 --- a/docs/tensor/creation.md +++ b/docs/tensor/creation.md @@ -24,6 +24,7 @@ ::: tinygrad.Tensor.randn ::: tinygrad.Tensor.randn_like ::: tinygrad.Tensor.randint +::: tinygrad.Tensor.randperm ::: tinygrad.Tensor.normal ::: tinygrad.Tensor.uniform ::: tinygrad.Tensor.scaled_uniform diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md index 465add4a2f..9c54475e74 100644 --- a/docs/tensor/ops.md +++ b/docs/tensor/ops.md @@ -40,6 +40,7 @@ ::: tinygrad.Tensor.masked_fill ::: tinygrad.Tensor.sort ::: tinygrad.Tensor.topk +::: tinygrad.Tensor.multinomial ## Neural Network (functional) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f842a48229..764a67dc08 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -867,12 +867,29 @@ class Tensor(MathTrait): return Tensor.normal(*shape, mean=0.0, std=std, **kwargs) @staticmethod - def randperm(n: int, *, device=None, dtype=dtypes.int32, **kwargs) -> Tensor: + def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor: + """ + Return a tensor with a random permutation of integers from 0 to n-1. + + ```python exec="true" source="above" session="tensor" result="python" + Tensor.manual_seed(42) + print(Tensor.randperm(4).numpy()) + ``` + """ r = Tensor.rand(n, device=device, **kwargs) _, indices = r.sort() return indices.cast(dtype) def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor: + """ + Sample from a multinomial distribution weighted by `self`. + + ```python exec="true" source="above" session="tensor" result="python" + Tensor.manual_seed(42) + t = Tensor.arange(10) + print(t.multinomial().numpy()) + ``` + """ assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive" assert replacement or num_samples == 1, "no replacement only supports num_samples = 1" weight = self.unsqueeze(0) if self.ndim == 1 else self @@ -1620,6 +1637,7 @@ class Tensor(MathTrait): mask = Tensor([True, False, True, False, False]) value = Tensor([-1, -2, -3, -4, -5]) print(t.masked_fill(mask, value).numpy()) + ``` """ return mask.where(value, self)