diff --git a/test/test_ops.py b/test/test_ops.py index 64daff191e..d7e9e0fab9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -240,7 +240,11 @@ class TestOps(unittest.TestCase): helper_test_op([], lambda: torch.linspace(5, 10, 30), lambda: Tensor.linspace(5, 10, 30), forward_only=True) helper_test_op([], lambda: torch.linspace(-5.5, 5.5, 10), lambda: Tensor.linspace(-5.5, 5.5, 10), forward_only=True) helper_test_op([], lambda: torch.linspace(5.5, -5.5, 10), lambda: Tensor.linspace(5.5, -5.5, 10), forward_only=True) - helper_test_op([], lambda: torch.linspace(5, 10, 3, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 3, dtype=dtypes.int32), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 3, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 3, dtype="int32"), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 20, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 20, dtype="int32"), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, -5, 20, dtype=torch.int32), lambda: Tensor.linspace(5, -5, 20, dtype="int32"), forward_only=True) + self.helper_test_exception([], lambda: torch.linspace(5, 10, 3, dtype=torch.bool), lambda: Tensor.linspace(5, 10, 3, dtype="bool"), + expected=(RuntimeError, ValueError)) self.helper_test_exception([], lambda: torch.linspace(1, 2, -1), lambda: Tensor.linspace(1, 2, -1), expected=(RuntimeError, ValueError)) def test_sum_fake(self): diff --git a/test/test_randomness.py b/test/test_randomness.py index 4258b45b30..4670731174 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -227,11 +227,15 @@ class TestRandomness(unittest.TestCase): def test_randint(self): self.assertFalse(normal_test(Tensor.randint)) - self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x))) + self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), + numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x))) + self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5, dtype="int32"), + numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x))) self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG") # check types of args with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3) with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5) + with self.assertRaises(TypeError): Tensor.randint((3, 4), low=1, high=3, dtype="float") with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32) def test_normal(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 15df5e5bfe..7d3824d44a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -628,7 +628,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ if steps < 0: raise ValueError("number of steps must be non-negative") - dtype = kwargs.pop("dtype", dtypes.default_float) + if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported") if steps == 1: return Tensor([start], dtype=dtype, **kwargs) return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype) @@ -754,7 +754,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers") - dtype = kwargs.pop("dtype", dtypes.int32) + dtype = to_dtype(kwargs.pop("dtype", dtypes.int32)) if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int") return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)