mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
raise error if setitem tensors have requires_grad (#4575)
* raise error if setitem tensors have requires_grad working on supporting this, first properly raises error * NotImplementedError
This commit is contained in:
@@ -81,5 +81,23 @@ class TestSetitem(unittest.TestCase):
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
|
||||
|
||||
class TestWithGrad(unittest.TestCase):
|
||||
def test_no_requires_grad_works(self):
|
||||
z = Tensor.rand(8, 8)
|
||||
x = Tensor.rand(8)
|
||||
z[:3] = x
|
||||
|
||||
def test_set_into_requires_grad(self):
|
||||
z = Tensor.rand(8, 8, requires_grad=True)
|
||||
x = Tensor.rand(8)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
z[:3] = x
|
||||
|
||||
def test_set_with_requires_grad(self):
|
||||
z = Tensor.rand(8, 8)
|
||||
x = Tensor.rand(8, requires_grad=True)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
z[:3] = x
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user