initialize Tensor grad same type as self (#3613)

* initialize Tensor grad same type as self

* also test different default float

* check dtype + try/finally

* don't test_gradient_dtype if f16 is not supported

* fix bad merge

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: David Hou
Date: 2024-03-22 17:33:18 -07:00
Committed by: GitHub
Parent: 8db7a6bbcc
Commit: fc11808a79
2 changed files with 16 additions and 2 deletions
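As a quick illustration of what the commit fixes, here is a minimal sketch that mirrors the new test below (it assumes the top-level imports from tinygrad resolve at this commit; otherwise the same names live in tinygrad.tensor and tinygrad.dtype): before the change, backward() seeded the gradient with dtypes.default_float, so a float16 tensor could receive a float32 grad and trip a dtype-mismatch assert in the lazy engine; after the change, the grad dtype follows the tensor's own dtype.

import numpy as np
from tinygrad import Tensor, dtypes  # assumed top-level exports at this commit

a = Tensor([1, 2, 3], dtype=dtypes.float16, requires_grad=True)
b = (a * 5).sum()
b.backward()                        # pre-fix: seed grad used dtypes.default_float (often float32)
assert a.grad.dtype == a.dtype      # post-fix: grad is created with the tensor's dtype
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])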


@@ -566,7 +566,21 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
     assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
 
 class TestImplicitFunctionTypeChange(unittest.TestCase):
+  @unittest.skipIf(not is_dtype_supported(dtypes.float16), "need float16")
+  def test_gradient_dtype(self):
+    for default_dtype in [dtypes.float16, dtypes.float32]:
+      old_default_float = dtypes.default_float
+      try:
+        dtypes.default_float = default_dtype
+        for datatype in [dtypes.float16, dtypes.float32]:
+          a = Tensor([1, 2, 3], dtype=datatype, requires_grad=True)
+          b = (a * 5).sum()
+          b.backward()  # if there is dtype mismatch, lazy should assert
+          assert a.grad.dtype == a.dtype
+          np.testing.assert_allclose(a.grad.numpy(), Tensor([5, 5, 5], dtype=datatype).numpy())
+      finally:
+        dtypes.default_float = old_default_float
   def test_functions(self):
     result = []
     for func in [
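The test above temporarily overrides dtypes.default_float and restores it in a finally block so a failure in one iteration cannot leak the override into later tests. A hypothetical helper (not part of this commit or of tinygrad) could wrap the same save/restore pattern in a context manager:

from contextlib import contextmanager
from tinygrad import dtypes

@contextmanager
def default_float(dt):
  # hypothetical convenience wrapper: same try/finally the test uses inline
  old = dtypes.default_float
  dtypes.default_float = dt
  try:
    yield
  finally:
    dtypes.default_float = old  # always restored, even if the body raises

# usage sketch:
#   with default_float(dtypes.float16):
#     ...  # dtype-sensitive code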


@@ -343,7 +343,7 @@ class Tensor:
     # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
     # this is "implicit gradient creation"
-    self.grad = Tensor(1.0, device=self.device, requires_grad=False)
+    self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
     for t0 in reversed(self.deepwalk()):
       if t0.grad is None: raise RuntimeError("tensor has no grad")
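The whole fix is the added dtype=self.dtype argument: the scalar seed grad now inherits the output tensor's dtype (and device) instead of falling back to dtypes.default_float. In isolation the changed constructor call behaves like this standalone sketch (not tinygrad internals; "out" stands in for self inside backward()):

from tinygrad import Tensor, dtypes

out = Tensor(3.0, dtype=dtypes.float16, requires_grad=True)
seed = Tensor(1.0, dtype=out.dtype, device=out.device, requires_grad=False)
assert seed.dtype == dtypes.float16  # previously this defaulted to dtypes.default_float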