mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor.linspace raises for dtype.bool (#7649)
also fixed an assert when passing str dtype to randint
This commit is contained in:
@@ -227,11 +227,15 @@ class TestRandomness(unittest.TestCase):
|
||||
|
||||
def test_randint(self):
|
||||
self.assertFalse(normal_test(Tensor.randint))
|
||||
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
|
||||
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5),
|
||||
numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
|
||||
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5, dtype="int32"),
|
||||
numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
|
||||
self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG")
|
||||
# check types of args
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3)
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5)
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=1, high=3, dtype="float")
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32)
|
||||
|
||||
def test_normal(self):
|
||||
|
||||
Reference in New Issue
Block a user