check arg types of Tensor.randint (#4751)

raise TypeError if low, high, dtype are not ints
This commit is contained in:
chenyu
2024-05-27 20:24:10 -04:00
committed by GitHub
parent 16756af13c
commit 53b9081aab
2 changed files with 8 additions and 2 deletions

View File

@@ -121,7 +121,11 @@ class TestRandomness(unittest.TestCase):
def test_randint(self):
self.assertFalse(normal_test(Tensor.randint))
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
self.assertTrue(Tensor.randint(1,device="CLANG").device=="CLANG")
self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG")
# check types of args
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3)
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5)
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32)
def test_normal(self):
self.assertTrue(normal_test(Tensor.normal))