mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
fix Tensor.to preserves grad.data (#3636)
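The change itself lands in Tensor.to: a moved tensor now carries its autograd state with it instead of dropping it. A minimal sketch of the idea, inferred from the commit title and the tests below; this is an approximation, not the literal code in tinygrad/tensor.py:

# Hedged sketch of Tensor.to's post-fix behavior; names mirror tinygrad,
# but this is illustrative, not the committed source.
def to(self, device):
  if device is None or device == self.device: return self
  # requires_grad travels with the tensor to the new device...
  ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
  # ...and so does an already-computed gradient, moved to the same device.
  if self.grad is not None: ret.grad = self.grad.to(device)
  return ret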
@@ -148,6 +148,7 @@ class TestMultiTensor(unittest.TestCase):
     out.mean().backward()
+    #for p in get_parameters(conv): p.grad.realize()
     optim.step()
     out.numpy()

   def test_lr_scheduler_OneCycleLR(self):
     from extra.lr_scheduler import OneCycleLR
@@ -520,6 +521,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     optim.zero_grad()
     out.mean().backward()
     optim.step()
     out.numpy()

   def test_unsynced_backprop_standalone_bn(self):
     from extra.lr_scheduler import OneCycleLR
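Both hunks above are context around multi-device training-loop tests. The loop they exercise looks roughly like the following sketch, assuming tinygrad's public nn API (Linear, SGD, and get_parameters are real tinygrad names; the model and shapes are illustrative):

from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters

model = Linear(4, 2)
optim = SGD(get_parameters(model), lr=0.01)
out = model(Tensor.rand(8, 4))
optim.zero_grad()
out.mean().backward()
# Before this fix, sharded tests realized each p.grad by hand; that
# workaround now survives only as the commented line in the first hunk.
optim.step()
out.numpy()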
@@ -341,10 +341,13 @@ class TestMoveTensor(unittest.TestCase):
          strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
   def test_to_preserves(self, src, dest, dtype, requires_grad):
     s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
+    if requires_grad: s.sum().backward()
     t = s.to(dest)
     np.testing.assert_equal(s.numpy(), t.numpy())
     assert s.dtype == t.dtype
     assert s.requires_grad == t.requires_grad
+    if requires_grad:
+      np.testing.assert_equal(s.grad.numpy(), t.grad.numpy())

   @given(strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
   def test_shard_preserves(self, dtype, requires_grad):
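Seen from user code, the new assertions guarantee the following behavior (a hedged example; the destination device depends on the environment, and "CLANG" here is only an assumption):

import numpy as np
from tinygrad import Tensor

s = Tensor([1.0, 2.0, 3.0], requires_grad=True)
s.sum().backward()             # s.grad is now [1., 1., 1.]
t = s.to("CLANG")              # assumed second device; pick any available one
assert t.requires_grad == s.requires_grad
np.testing.assert_equal(t.grad.numpy(), s.grad.numpy())  # grad moved along

test_shard_preserves, whose signature closes the last hunk, checks the same dtype and requires_grad preservation for Tensor.shard.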