From 8a2c23d3dc7f392b022b89ca5ce94f345b53fa56 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 9 Feb 2026 10:37:08 -0500 Subject: [PATCH] raise RuntimeError for setitem dtype mismatch (#14642) --- test/test_setitem.py | 18 +++++++++++------- tinygrad/tensor.py | 3 ++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/test/test_setitem.py b/test/test_setitem.py index 82f37f5324..1f6cdc02d6 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -13,7 +13,7 @@ class TestSetitem(unittest.TestCase): ((6,6), (slice(2,4), slice(3,5)), 1.0), ((6,6), (3, 4), 1.0), ((6,6), (3, None, 4, None), 1.0), - ((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4)), + ((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4.0)), ((4,4,4,4), (Ellipsis, slice(1,3)), 4), ((4,4,4,4), (2, slice(1,3), None, 1), 4), ((4,4,4,4), (slice(1,3), slice(None), slice(0,4,2)), 4), @@ -50,6 +50,10 @@ class TestSetitem(unittest.TestCase): t[1] = v self.assertEqual(t.dtype, dt) + def test_setitem_dtype_mismatch(self): + t = Tensor.zeros(6, dtype=dtypes.float).contiguous().realize() + with self.assertRaises(RuntimeError): t[2:4] = Tensor([1, 2], dtype=dtypes.int) + def test_setitem_into_noncontiguous(self): t = Tensor.ones(4) with self.assertRaises(RuntimeError): t[1] = 5 @@ -193,23 +197,23 @@ class TestSetitem(unittest.TestCase): def test_setitem_advanced_indexing(self): # Example from https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing - t = Tensor.zeros(10,20,30,40,50).contiguous() + t = Tensor.zeros(10,20,30,40,50, dtype=dtypes.int).contiguous() ind_1 = Tensor([5,3,7,8]) ind_2 = Tensor([[[0],[1],[2]],[[3],[4],[5]]]) v = Tensor.arange(2*3*4*10*30*50).reshape(2,3,4,10,30,50) t[:, ind_1, :, ind_2, :] = v - n = np.zeros((10,20,30,40,50)) + n = np.zeros((10,20,30,40,50), dtype=np.int32) n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy() - np.testing.assert_allclose(t.numpy(), n) + np.testing.assert_equal(t.numpy(), n) def test_setitem_2d_tensor_indexing(self): - t = Tensor.zeros(2).contiguous() + t = Tensor.zeros(2, dtype=dtypes.int).contiguous() index = Tensor([[0, 1], [1,0]]) v = Tensor.arange(2*2).reshape(2, 2).contiguous() t[index] = v - n = np.zeros((2,)) + n = np.zeros((2,), dtype=np.int32) n[index.numpy()] = v.numpy() - np.testing.assert_allclose(t.numpy(), n) + np.testing.assert_equal(t.numpy(), n) @unittest.skip("slow") def test_setitem_tensor_indexing_fuzz(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 6eedeccd9a..86c072a60e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1279,6 +1279,7 @@ class Tensor(OpMixin): return self._getitem(indices) def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: + if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}") if isinstance(self.device, str) and self.device.startswith("DISK"): self.realize()._getitem(indices).assign(v) return @@ -1291,7 +1292,7 @@ class Tensor(OpMixin): else: # basic setitem self.realize() if not self.uop.is_writable_view(): raise RuntimeError("setitem target must be a writable view backed by a buffer") - res.assign(v.cast(res.dtype)).realize() + res.assign(v).realize() def __delitem__(self, indices) -> None: raise TypeError("Tensor does not support deleting items")