raise error if setitem tensors have requires_grad (#4575)

* raise error if setitem tensors have requires_grad

working on supporting this, first properly raises error

* NotImplementedError
This commit is contained in:
chenyu
2024-05-13 18:56:47 -04:00
committed by GitHub
parent f7d08bd454
commit 0fa57b8ce9
2 changed files with 20 additions and 0 deletions

View File

@@ -81,5 +81,23 @@ class TestSetitem(unittest.TestCase):
np.testing.assert_allclose(t.numpy(), n)
np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
class TestWithGrad(unittest.TestCase):
def test_no_requires_grad_works(self):
z = Tensor.rand(8, 8)
x = Tensor.rand(8)
z[:3] = x
def test_set_into_requires_grad(self):
z = Tensor.rand(8, 8, requires_grad=True)
x = Tensor.rand(8)
with self.assertRaises(NotImplementedError):
z[:3] = x
def test_set_with_requires_grad(self):
z = Tensor.rand(8, 8)
x = Tensor.rand(8, requires_grad=True)
with self.assertRaises(NotImplementedError):
z[:3] = x
if __name__ == '__main__':
unittest.main()