fix multinomial

This commit is contained in:
qazal
2023-12-23 11:41:06 +02:00
parent a27bbd65db
commit b66a06ba67

View File

@@ -229,7 +229,7 @@ class Tensor:
weight = self.unsqueeze(0) if self.ndim == 1 else self
cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
indices = ((unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).cast(dtypes.int)).sum(2).permute((1, 0))
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass *****