From 0fa57b8ce9ce4844d2dc7a5d4dfcfcabda695d98 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Mon, 13 May 2024 18:56:47 -0400
Subject: [PATCH] raise error if setitem tensors have requires_grad (#4575)

* raise error if setitem tensors have requires_grad

working on supporting this, first properly raises error

* NotImplementedError
---
 test/test_setitem.py | 18 ++++++++++++++++++
 tinygrad/tensor.py   |  2 ++
 2 files changed, 20 insertions(+)

diff --git a/test/test_setitem.py b/test/test_setitem.py
index 7e1aaabb89..d28f9eb936 100644
--- a/test/test_setitem.py
+++ b/test/test_setitem.py
@@ -81,5 +81,23 @@ class TestSetitem(unittest.TestCase):
     np.testing.assert_allclose(t.numpy(), n)
     np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
 
+class TestWithGrad(unittest.TestCase):
+  def test_no_requires_grad_works(self):
+    z = Tensor.rand(8, 8)
+    x = Tensor.rand(8)
+    z[:3] = x
+
+  def test_set_into_requires_grad(self):
+    z = Tensor.rand(8, 8, requires_grad=True)
+    x = Tensor.rand(8)
+    with self.assertRaises(NotImplementedError):
+      z[:3] = x
+
+  def test_set_with_requires_grad(self):
+    z = Tensor.rand(8, 8)
+    x = Tensor.rand(8, requires_grad=True)
+    with self.assertRaises(NotImplementedError):
+      z[:3] = x
+
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e897621b34..8ab65c9af0 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -823,6 +823,8 @@ class Tensor:
     assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
     if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
     if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
+    if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
+
     assign_to = self.realize().__getitem__(indices)
     # NOTE: contiguous to prevent const folding.
     v = v.cast(assign_to.dtype)._broadcast_to(broadcast_shape(assign_to.shape, v.shape)).contiguous()
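
Note on the new behavior (not part of the patch): a minimal sketch of what this
change does at the Python level, assuming a tinygrad build with this patch
applied. The import path mirrors the module touched by the diff; the shapes and
the error message come straight from the tests and guard above:

  from tinygrad.tensor import Tensor

  # plain setitem still works: neither side tracks gradients
  z = Tensor.rand(8, 8)
  x = Tensor.rand(8)
  z[:3] = x

  # once either the target or the value has requires_grad=True,
  # __setitem__ raises instead of proceeding with an unsupported in-place write
  zg = Tensor.rand(8, 8, requires_grad=True)
  try:
    zg[:3] = x
  except NotImplementedError as e:
    print(e)  # setitem with requires_grad is not supported

Since the guard runs before `assign_to = self.realize().__getitem__(indices)`,
the error fires eagerly at the __setitem__ call itself, before the target is
realized.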