diff --git a/test/test_randomness.py b/test/test_randomness.py index 47b8d02cba..aa742d36d9 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -121,7 +121,11 @@ class TestRandomness(unittest.TestCase): def test_randint(self): self.assertFalse(normal_test(Tensor.randint)) self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x))) - self.assertTrue(Tensor.randint(1,device="CLANG").device=="CLANG") + self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG") + # check types of args + with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3) + with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5) + with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32) def test_normal(self): self.assertTrue(normal_test(Tensor.normal)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8c876cd862..703c748055 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -557,7 +557,9 @@ class Tensor: print(Tensor.randint(2, 3, low=5, high=10).numpy()) ``` """ - assert dtypes.is_int(dtype := kwargs.pop("dtype", dtypes.int32)), f"Unsupported dtype {dtype} for randint" + if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers") + dtype = kwargs.pop("dtype", dtypes.int32) + if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int") return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs) @staticmethod