mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
shard and to should preserve requires_grad (#3224)
dtypes are inferred from underlying lazydata, requires_grad needs to be passed explicitly
This commit is contained in:
@@ -3,8 +3,12 @@ import torch
|
||||
import unittest, copy
|
||||
import mmap
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.helpers import temp
|
||||
from tinygrad.helpers import temp, CI
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None)
|
||||
settings.load_profile("my_profile")
|
||||
|
||||
x_init = np.random.randn(1,3).astype(np.float32)
|
||||
U_init = np.random.randn(3,3).astype(np.float32)
|
||||
@@ -315,6 +319,26 @@ class TestTinygrad(unittest.TestCase):
|
||||
assert type(reshaped_item) == type(a), a
|
||||
np.testing.assert_allclose(reshaped_item, a), a
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
class TestMoveTensor(unittest.TestCase):
|
||||
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
||||
@given(strat.sampled_from([d0, d1]), strat.sampled_from([d0, d1]),
|
||||
strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
|
||||
def test_to_preserves(self, src, dest, dtype, requires_grad):
|
||||
s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
|
||||
t = s.to(dest)
|
||||
np.testing.assert_equal(s.numpy(), t.numpy())
|
||||
assert s.dtype == t.dtype
|
||||
assert s.requires_grad == t.requires_grad
|
||||
|
||||
@given(strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
|
||||
def test_shard_preserves(self, dtype, requires_grad):
|
||||
s = Tensor([1, 2, 3], dtype=dtype, requires_grad=requires_grad)
|
||||
t = s.shard((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"))
|
||||
np.testing.assert_equal(s.numpy(), t.numpy())
|
||||
assert s.dtype == t.dtype
|
||||
assert s.requires_grad == t.requires_grad
|
||||
|
||||
class TestZeroShapeTensor(unittest.TestCase):
|
||||
def test_shape_stride(self):
|
||||
t = Tensor.rand(3, 2, 0)
|
||||
|
||||
Reference in New Issue
Block a user