mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix multinomial
This commit is contained in:
@@ -229,7 +229,7 @@ class Tensor:
|
||||
weight = self.unsqueeze(0) if self.ndim == 1 else self
|
||||
cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1)
|
||||
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
|
||||
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
||||
indices = ((unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).cast(dtypes.int)).sum(2).permute((1, 0))
|
||||
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
|
||||
Reference in New Issue
Block a user