mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix return dtype of getitem Tensor indexing (#4158)
the use of sum can auto-upcast the result. fixed by using the data dtype as the acc_dtype
This commit is contained in:
@@ -439,6 +439,13 @@ class TestTypeSpec(unittest.TestCase):
|
||||
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
|
||||
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
|
||||
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
|
||||
X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype)
|
||||
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
|
||||
X = X_data[indices]
|
||||
assert X.dtype == X_data.dtype
|
||||
|
||||
class TestTypePromotion(unittest.TestCase):
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_self_promo_to_self(self, dtype):
|
||||
|
||||
Reference in New Issue
Block a user