fix return dtype of getitem Tensor indexing (#4158)

the use of sum can auto-upcast the result. fixed by using the data dtype as the acc_dtype
This commit is contained in:
chenyu
2024-04-12 15:55:02 -04:00
committed by GitHub
parent f6c8032e5d
commit d9c5a2b1bb
2 changed files with 9 additions and 2 deletions

View File

@@ -439,6 +439,13 @@ class TestTypeSpec(unittest.TestCase):
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype)
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
X = X_data[indices]
assert X.dtype == X_data.dtype
class TestTypePromotion(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_self_promo_to_self(self, dtype):