mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
raise RuntimeError for setitem dtype mismatch (#14642)
This commit is contained in:
@@ -13,7 +13,7 @@ class TestSetitem(unittest.TestCase):
|
||||
((6,6), (slice(2,4), slice(3,5)), 1.0),
|
||||
((6,6), (3, 4), 1.0),
|
||||
((6,6), (3, None, 4, None), 1.0),
|
||||
((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4)),
|
||||
((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4.0)),
|
||||
((4,4,4,4), (Ellipsis, slice(1,3)), 4),
|
||||
((4,4,4,4), (2, slice(1,3), None, 1), 4),
|
||||
((4,4,4,4), (slice(1,3), slice(None), slice(0,4,2)), 4),
|
||||
@@ -50,6 +50,10 @@ class TestSetitem(unittest.TestCase):
|
||||
t[1] = v
|
||||
self.assertEqual(t.dtype, dt)
|
||||
|
||||
def test_setitem_dtype_mismatch(self):
|
||||
t = Tensor.zeros(6, dtype=dtypes.float).contiguous().realize()
|
||||
with self.assertRaises(RuntimeError): t[2:4] = Tensor([1, 2], dtype=dtypes.int)
|
||||
|
||||
def test_setitem_into_noncontiguous(self):
|
||||
t = Tensor.ones(4)
|
||||
with self.assertRaises(RuntimeError): t[1] = 5
|
||||
@@ -193,23 +197,23 @@ class TestSetitem(unittest.TestCase):
|
||||
|
||||
def test_setitem_advanced_indexing(self):
|
||||
# Example from https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
|
||||
t = Tensor.zeros(10,20,30,40,50).contiguous()
|
||||
t = Tensor.zeros(10,20,30,40,50, dtype=dtypes.int).contiguous()
|
||||
ind_1 = Tensor([5,3,7,8])
|
||||
ind_2 = Tensor([[[0],[1],[2]],[[3],[4],[5]]])
|
||||
v = Tensor.arange(2*3*4*10*30*50).reshape(2,3,4,10,30,50)
|
||||
t[:, ind_1, :, ind_2, :] = v
|
||||
n = np.zeros((10,20,30,40,50))
|
||||
n = np.zeros((10,20,30,40,50), dtype=np.int32)
|
||||
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
np.testing.assert_equal(t.numpy(), n)
|
||||
|
||||
def test_setitem_2d_tensor_indexing(self):
|
||||
t = Tensor.zeros(2).contiguous()
|
||||
t = Tensor.zeros(2, dtype=dtypes.int).contiguous()
|
||||
index = Tensor([[0, 1], [1,0]])
|
||||
v = Tensor.arange(2*2).reshape(2, 2).contiguous()
|
||||
t[index] = v
|
||||
n = np.zeros((2,))
|
||||
n = np.zeros((2,), dtype=np.int32)
|
||||
n[index.numpy()] = v.numpy()
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
np.testing.assert_equal(t.numpy(), n)
|
||||
|
||||
@unittest.skip("slow")
|
||||
def test_setitem_tensor_indexing_fuzz(self):
|
||||
|
||||
@@ -1279,6 +1279,7 @@ class Tensor(OpMixin):
|
||||
return self._getitem(indices)
|
||||
|
||||
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
|
||||
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
||||
self.realize()._getitem(indices).assign(v)
|
||||
return
|
||||
@@ -1291,7 +1292,7 @@ class Tensor(OpMixin):
|
||||
else: # basic setitem
|
||||
self.realize()
|
||||
if not self.uop.is_writable_view(): raise RuntimeError("setitem target must be a writable view backed by a buffer")
|
||||
res.assign(v.cast(res.dtype)).realize()
|
||||
res.assign(v).realize()
|
||||
|
||||
def __delitem__(self, indices) -> None:
|
||||
raise TypeError("Tensor does not support deleting items")
|
||||
|
||||
Reference in New Issue
Block a user