diff --git a/test/test_randomness.py b/test/test_randomness.py index 435778efb4..5c0c5c8c60 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -251,6 +251,9 @@ class TestRandomness(unittest.TestCase): self.assertTrue(normal_test(Tensor.randn)) self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x))) + def test_randn_device(self): + self.assertEqual(Tensor.randn(3,3,device="CPU").device, "CPU") + @given(strat.sampled_from([dtypes.float, dtypes.float16, dtypes.bfloat16])) @unittest.skipIf(Device.DEFAULT in ["HSA", "AMD"], "bfloat16 local buffer broken in HSA") def test_randn_finite(self, default_float): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 952459ac40..488df0fb4c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -762,7 +762,7 @@ class Tensor(MathTrait): print(Tensor.randn(2, 3).numpy()) ``` """ - return Tensor.empty(*shape).randn_like(dtype=dtype, requires_grad=requires_grad) + return Tensor.empty(*shape, **kwargs).randn_like(dtype=dtype, requires_grad=requires_grad) @staticmethod def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor: