fix device arg to Tensor.randn (#11194)

* fix device arg to Tensor.randn

* simpler test

* self.assertEqual
This commit is contained in:
nimlgen
2025-07-12 20:51:59 +03:00
committed by GitHub
parent 6283d50224
commit 110cff3f2e
2 changed files with 4 additions and 1 deletions

View File

@@ -251,6 +251,9 @@ class TestRandomness(unittest.TestCase):
self.assertTrue(normal_test(Tensor.randn))
self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x)))
def test_randn_device(self):
self.assertEqual(Tensor.randn(3,3,device="CPU").device, "CPU")
@given(strat.sampled_from([dtypes.float, dtypes.float16, dtypes.bfloat16]))
@unittest.skipIf(Device.DEFAULT in ["HSA", "AMD"], "bfloat16 local buffer broken in HSA")
def test_randn_finite(self, default_float):