fix Tensor.to preserves grad.data (#3636)

Author: chenyu
Date: 2024-03-06 21:44:49 -05:00
Committed by: GitHub
parent d33311ebe0
commit 4552248c84
3 changed files with 6 additions and 1 deletion

@@ -148,6 +148,7 @@ class TestMultiTensor(unittest.TestCase):
     out.mean().backward()
     #for p in get_parameters(conv): p.grad.realize()
     optim.step()
+    out.numpy()

   def test_lr_scheduler_OneCycleLR(self):
     from extra.lr_scheduler import OneCycleLR
@@ -520,6 +521,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     optim.zero_grad()
     out.mean().backward()
     optim.step()
+    out.numpy()

   def test_unsynced_backprop_standalone_bn(self):
     from extra.lr_scheduler import OneCycleLR

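The out.numpy() calls added in the two hunks above are realization points: tinygrad builds computation lazily, so without them the backward pass and optimizer step set up by the test may never actually execute. A minimal sketch of that behavior, assuming the tinygrad API of this era, where Tensor.numpy() forces the lazy graph to run:

from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
out = (x * 2).sum()  # builds a lazy graph; no kernels have run yet
out.backward()       # gradients are also constructed lazily
out.numpy()          # realization point: kernels get scheduled and executed here
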
@@ -341,10 +341,13 @@ class TestMoveTensor(unittest.TestCase):
          strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
   def test_to_preserves(self, src, dest, dtype, requires_grad):
     s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
+    if requires_grad: s.sum().backward()
     t = s.to(dest)
     np.testing.assert_equal(s.numpy(), t.numpy())
     assert s.dtype == t.dtype
     assert s.requires_grad == t.requires_grad
+    if requires_grad:
+      np.testing.assert_equal(s.grad.numpy(), t.grad.numpy())

   @given(strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
   def test_shard_preserves(self, dtype, requires_grad):
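
The new assertions above pin down the behavior this commit fixes: after backward(), Tensor.to must carry the populated grad to the destination device along with dtype and requires_grad. A hedged usage sketch of the fixed behavior; the device name below is an assumption, not taken from the diff, so substitute any backend available on your machine:

import numpy as np
from tinygrad import Tensor

s = Tensor([1.0, 2.0, 3.0], requires_grad=True)
s.sum().backward()   # populates s.grad on the source device
t = s.to("CLANG")    # "CLANG" is an assumed device name; any available backend works
assert t.requires_grad == s.requires_grad
np.testing.assert_equal(s.grad.numpy(), t.grad.numpy())  # grad moved with the tensor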