Tensor.uniform with dtype=int bug fix (#1593)

This commit is contained in:
Jordan Wright
2023-08-26 01:59:53 -04:00
committed by GitHub
parent f702a8f497
commit 25be7f745d
2 changed files with 12 additions and 1 deletions

View File

@@ -179,7 +179,9 @@ class Tensor:
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
@staticmethod
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor:
dtype = kwargs.pop("dtype", Tensor.default_type)
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)