mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix device arg to Tensor.randn (#11194)
* fix device arg to Tensor.randn * simpler test * self.assertEqual
This commit is contained in:
@@ -251,6 +251,9 @@ class TestRandomness(unittest.TestCase):
|
||||
self.assertTrue(normal_test(Tensor.randn))
|
||||
self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x)))
|
||||
|
||||
def test_randn_device(self):
|
||||
self.assertEqual(Tensor.randn(3,3,device="CPU").device, "CPU")
|
||||
|
||||
@given(strat.sampled_from([dtypes.float, dtypes.float16, dtypes.bfloat16]))
|
||||
@unittest.skipIf(Device.DEFAULT in ["HSA", "AMD"], "bfloat16 local buffer broken in HSA")
|
||||
def test_randn_finite(self, default_float):
|
||||
|
||||
Reference in New Issue
Block a user