raise RuntimeError for setitem dtype mismatch (#14642)

This commit is contained in:
chenyu
2026-02-09 10:37:08 -05:00
committed by GitHub
parent 80b0119cef
commit 8a2c23d3dc
2 changed files with 13 additions and 8 deletions

View File

@@ -13,7 +13,7 @@ class TestSetitem(unittest.TestCase):
((6,6), (slice(2,4), slice(3,5)), 1.0),
((6,6), (3, 4), 1.0),
((6,6), (3, None, 4, None), 1.0),
((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4)),
((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4.0)),
((4,4,4,4), (Ellipsis, slice(1,3)), 4),
((4,4,4,4), (2, slice(1,3), None, 1), 4),
((4,4,4,4), (slice(1,3), slice(None), slice(0,4,2)), 4),
@@ -50,6 +50,10 @@ class TestSetitem(unittest.TestCase):
t[1] = v
self.assertEqual(t.dtype, dt)
def test_setitem_dtype_mismatch(self):
t = Tensor.zeros(6, dtype=dtypes.float).contiguous().realize()
with self.assertRaises(RuntimeError): t[2:4] = Tensor([1, 2], dtype=dtypes.int)
def test_setitem_into_noncontiguous(self):
t = Tensor.ones(4)
with self.assertRaises(RuntimeError): t[1] = 5
@@ -193,23 +197,23 @@ class TestSetitem(unittest.TestCase):
def test_setitem_advanced_indexing(self):
# Example from https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
t = Tensor.zeros(10,20,30,40,50).contiguous()
t = Tensor.zeros(10,20,30,40,50, dtype=dtypes.int).contiguous()
ind_1 = Tensor([5,3,7,8])
ind_2 = Tensor([[[0],[1],[2]],[[3],[4],[5]]])
v = Tensor.arange(2*3*4*10*30*50).reshape(2,3,4,10,30,50)
t[:, ind_1, :, ind_2, :] = v
n = np.zeros((10,20,30,40,50))
n = np.zeros((10,20,30,40,50), dtype=np.int32)
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
np.testing.assert_allclose(t.numpy(), n)
np.testing.assert_equal(t.numpy(), n)
def test_setitem_2d_tensor_indexing(self):
t = Tensor.zeros(2).contiguous()
t = Tensor.zeros(2, dtype=dtypes.int).contiguous()
index = Tensor([[0, 1], [1,0]])
v = Tensor.arange(2*2).reshape(2, 2).contiguous()
t[index] = v
n = np.zeros((2,))
n = np.zeros((2,), dtype=np.int32)
n[index.numpy()] = v.numpy()
np.testing.assert_allclose(t.numpy(), n)
np.testing.assert_equal(t.numpy(), n)
@unittest.skip("slow")
def test_setitem_tensor_indexing_fuzz(self):

View File

@@ -1279,6 +1279,7 @@ class Tensor(OpMixin):
return self._getitem(indices)
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
if isinstance(self.device, str) and self.device.startswith("DISK"):
self.realize()._getitem(indices).assign(v)
return
@@ -1291,7 +1292,7 @@ class Tensor(OpMixin):
else: # basic setitem
self.realize()
if not self.uop.is_writable_view(): raise RuntimeError("setitem target must be a writable view backed by a buffer")
res.assign(v.cast(res.dtype)).realize()
res.assign(v).realize()
def __delitem__(self, indices) -> None:
raise TypeError("Tensor does not support deleting items")