Tensor.shard and Tensor.to should preserve requires_grad (#3224)

dtypes are inferred from the underlying lazydata, but requires_grad needs to be passed explicitly.
This commit is contained in:
chenyu
2024-01-24 00:15:10 -05:00
committed by GitHub
parent 23b084e70a
commit 2f4b3ab1c0
2 changed files with 27 additions and 3 deletions

View File

@@ -3,8 +3,12 @@ import torch
import unittest, copy
import mmap
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import temp
from tinygrad.helpers import temp, CI
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
# Register and activate a hypothesis profile named "my_profile": up to 200
# examples per property test, with the per-example deadline disabled.
settings.register_profile("my_profile", max_examples=200, deadline=None)
settings.load_profile("my_profile")
# Module-level random float32 fixtures: x_init has shape (1, 3), U_init has shape (3, 3).
x_init = np.random.randn(1,3).astype(np.float32)
U_init = np.random.randn(3,3).astype(np.float32)
@@ -315,6 +319,26 @@ class TestTinygrad(unittest.TestCase):
assert type(reshaped_item) == type(a), a
np.testing.assert_allclose(reshaped_item, a), a
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestMoveTensor(unittest.TestCase):
    """Property tests that `Tensor.to` and `Tensor.shard` keep data, dtype and requires_grad.

    The whole class is skipped when `CI` is truthy and the default device is
    one of "GPU", "CUDA" or "METAL".
    """

    # Two device names derived from the default device, used as move source/destination.
    d0 = f"{Device.DEFAULT}:0"
    d1 = f"{Device.DEFAULT}:1"

    @given(
        strat.sampled_from([d0, d1]),
        strat.sampled_from([d0, d1]),
        strat.sampled_from([dtypes.float16, dtypes.float32]),
        strat.sampled_from([True, False, None]),
    )
    def test_to_preserves(self, src, dest, dtype, requires_grad):
        # Build a tensor on `src`, move it to `dest`, then compare values,
        # dtype and requires_grad between the original and the moved tensor.
        original = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
        moved = original.to(dest)
        np.testing.assert_equal(original.numpy(), moved.numpy())
        assert original.dtype == moved.dtype
        assert original.requires_grad == moved.requires_grad

    @given(
        strat.sampled_from([dtypes.float16, dtypes.float32]),
        strat.sampled_from([True, False, None]),
    )
    def test_shard_preserves(self, dtype, requires_grad):
        # Shard a tensor across the ":0" and ":1" variants of the default device,
        # then compare values, dtype and requires_grad with the unsharded tensor.
        original = Tensor([1, 2, 3], dtype=dtype, requires_grad=requires_grad)
        targets = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
        sharded = original.shard(targets)
        np.testing.assert_equal(original.numpy(), sharded.numpy())
        assert original.dtype == sharded.dtype
        assert original.requires_grad == sharded.requires_grad
class TestZeroShapeTensor(unittest.TestCase):
def test_shape_stride(self):
t = Tensor.rand(3, 2, 0)