Fix bug where Tensor.randn returns inf (#1192)

* fix randn inf bug

* add test

* more compact test

* clarify test purpose
This commit is contained in:
fluffy χατγιρλ
2023-07-08 21:03:46 +02:00
committed by GitHub
parent d9c1d81e99
commit 628ee46627
2 changed files with 8 additions and 1 deletions

View File

@@ -177,7 +177,7 @@ class Tensor:
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
src = Tensor.rand(2, *shape, **kwargs)
return src[0].mul(2*pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
return src[0].mul(2*pi).cos().mul(((1 - src[1])).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
@staticmethod
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low