mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 02:21:40 -05:00
Tensor.uniform with dtype=int bug fix (#1593)
This commit is contained in:
@@ -179,7 +179,9 @@ class Tensor:
|
||||
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
|
||||
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor:
|
||||
dtype = kwargs.pop("dtype", Tensor.default_type)
|
||||
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
|
||||
|
||||
@staticmethod
|
||||
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)
|
||||
|
||||
Reference in New Issue
Block a user