diff --git a/test/unit/test_indexing.py b/test/unit/test_indexing.py index c95e69193f..31ab4ff124 100644 --- a/test/unit/test_indexing.py +++ b/test/unit/test_indexing.py @@ -1000,7 +1000,7 @@ def assert_backward_eq(tensor: Tensor, indexer): def get_set_tensor(indexed: Tensor, indexer): set_size = indexed[indexer].shape set_count = indexed[indexer].numel() - set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size) #.cast(dtypes.float64) + set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size).cast(indexed.dtype) return set_tensor @slow