Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
initialize Tensor grad same type as self (#3613)
* initialize Tensor grad same type as self
* also test different default float
* check dtype + try/finally
* don't test_gradient_dtype if f16 is not supported
* fix bad merge

Co-authored-by: chenyu <chenyu@fastmail.com>
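For context, here is a minimal sketch of the user-visible behavior this change targets. It assumes the Tensor/dtypes API that appears in the diff below; the "from tinygrad import Tensor, dtypes" import path is an assumption and is not part of this commit.

# sketch: gradients now inherit the dtype of the tensor they belong to
from tinygrad import Tensor, dtypes

a = Tensor([1, 2, 3], dtype=dtypes.float16, requires_grad=True)
b = (a * 5).sum()
b.backward()

# previously the seed grad was created with the default float dtype, so a.grad
# could come back as float32 even though a is float16; now it matches a.dtype
assert a.grad.dtype == a.dtype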
@@ -566,7 +566,21 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
     assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
 
 class TestImplicitFunctionTypeChange(unittest.TestCase):
+  @unittest.skipIf(not is_dtype_supported(dtypes.float16), "need float16")
+  def test_gradient_dtype(self):
+    for default_dtype in [dtypes.float16, dtypes.float32]:
+      old_default_float = dtypes.default_float
+      try:
+        dtypes.default_float = default_dtype
+        for datatype in [dtypes.float16, dtypes.float32]:
+          a = Tensor([1, 2, 3], dtype=datatype, requires_grad=True)
+          b = (a * 5).sum()
+          b.backward()  # if there is dtype mismatch, lazy should assert
+          assert a.grad.dtype == a.dtype
+          np.testing.assert_allclose(a.grad.numpy(), Tensor([5, 5, 5], dtype=datatype).numpy())
+      finally:
+        dtypes.default_float = old_default_float
   def test_functions(self):
     result = []
     for func in [
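The test above mutates the module-level dtypes.default_float and restores it in a finally block so a failing assertion cannot leak the override into later tests. If the same override were needed in several places, the pattern could be wrapped in a context manager; the sketch below is a hypothetical helper (temp_default_float is not part of tinygrad), built only on the dtypes.default_float attribute the test already uses.

# hypothetical helper (not in tinygrad): temporarily override dtypes.default_float
from contextlib import contextmanager
from tinygrad import dtypes

@contextmanager
def temp_default_float(dtype):
  old = dtypes.default_float     # remember the current module-level default
  dtypes.default_float = dtype   # apply the override
  try:
    yield
  finally:
    dtypes.default_float = old   # always restore, mirroring the test's try/finally

# usage sketch: everything in the block sees float16 as the default float
# with temp_default_float(dtypes.float16):
#   assert Tensor([1.0, 2.0]).dtype == dtypes.float16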
@@ -343,7 +343,7 @@ class Tensor:
 
     # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
     # this is "implicit gradient creation"
-    self.grad = Tensor(1.0, device=self.device, requires_grad=False)
+    self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
 
     for t0 in reversed(self.deepwalk()):
       if t0.grad is None: raise RuntimeError("tensor has no grad")
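The tensor.py hunk is the whole fix: the seed gradient built during "implicit gradient creation" now inherits the output tensor's dtype instead of the default float, which is what keeps a.grad.dtype equal to a.dtype in the test above. A framework-free toy illustration of the idea (NumPy only; the names here are illustrative, not tinygrad internals):

# toy illustration of "implicit gradient creation" (NumPy only, not tinygrad internals)
import numpy as np

def seed_grad(output: np.ndarray) -> np.ndarray:
  # backward() on a scalar output seeds the chain with a gradient of 1;
  # building it with the output's own dtype avoids a float16/float32 mismatch
  return np.ones_like(output)   # ones_like inherits output.dtype

out = np.asarray(0.5, dtype=np.float16)   # pretend scalar loss
assert seed_grad(out).dtype == np.float16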