mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
Add some more docs (#10634)
* more docs * Add multinomial to ops * better doc
This commit is contained in:
@@ -24,6 +24,7 @@
|
||||
::: tinygrad.Tensor.randn
|
||||
::: tinygrad.Tensor.randn_like
|
||||
::: tinygrad.Tensor.randint
|
||||
::: tinygrad.Tensor.randperm
|
||||
::: tinygrad.Tensor.normal
|
||||
::: tinygrad.Tensor.uniform
|
||||
::: tinygrad.Tensor.scaled_uniform
|
||||
|
||||
@@ -40,6 +40,7 @@
|
||||
::: tinygrad.Tensor.masked_fill
|
||||
::: tinygrad.Tensor.sort
|
||||
::: tinygrad.Tensor.topk
|
||||
::: tinygrad.Tensor.multinomial
|
||||
|
||||
## Neural Network (functional)
|
||||
|
||||
|
||||
@@ -867,12 +867,29 @@ class Tensor(MathTrait):
|
||||
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def randperm(n: int, *, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
|
||||
def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
|
||||
"""
|
||||
Return a tensor with a random permutation of integers from 0 to n-1.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.randperm(4).numpy())
|
||||
```
|
||||
"""
|
||||
r = Tensor.rand(n, device=device, **kwargs)
|
||||
_, indices = r.sort()
|
||||
return indices.cast(dtype)
|
||||
|
||||
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
|
||||
"""
|
||||
Sample from a multinomial distribution weighted by `self`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.arange(10)
|
||||
print(t.multinomial().numpy())
|
||||
```
|
||||
"""
|
||||
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
|
||||
assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
|
||||
weight = self.unsqueeze(0) if self.ndim == 1 else self
|
||||
@@ -1620,6 +1637,7 @@ class Tensor(MathTrait):
|
||||
mask = Tensor([True, False, True, False, False])
|
||||
value = Tensor([-1, -2, -3, -4, -5])
|
||||
print(t.masked_fill(mask, value).numpy())
|
||||
```
|
||||
"""
|
||||
return mask.where(value, self)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user