From 271446e3eb373575b4978b40dc5b62895d6dce5f Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 21 Sep 2022 11:16:02 -0400
Subject: [PATCH] set requires_grad to None (#387)

* set requires_grad to None

* some things need gradients

* hmm, why was get_parameters filtering
---
 README.md            |  4 ++--
 extra/utils.py       |  3 ++-
 test/test_gc.py      | 10 +++++-----
 test/test_ops.py     |  2 +-
 test/test_tensor.py  | 12 ++++++------
 tinygrad/nn/optim.py |  5 +++++
 tinygrad/tensor.py   |  7 +++++--
 7 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/README.md b/README.md
index 4fe371393d..bf4b87491e 100644
--- a/README.md
+++ b/README.md
@@ -29,8 +29,8 @@ python3 setup.py develop
 ```python
 from tinygrad.tensor import Tensor
 
-x = Tensor.eye(3)
-y = Tensor([[2.0,0,-2.0]])
+x = Tensor.eye(3, requires_grad=True)
+y = Tensor([[2.0,0,-2.0]], requires_grad=True)
 z = y.matmul(x).sum()
 z.backward()
 
diff --git a/extra/utils.py b/extra/utils.py
index ea3fa40e0b..dcd73a6215 100644
--- a/extra/utils.py
+++ b/extra/utils.py
@@ -23,10 +23,11 @@ def fetch(url):
     os.rename(fp+".tmp", fp)
   return dat
 
+# TODO: move this to optim.py?
 def get_parameters(obj):
   parameters = []
   if isinstance(obj, Tensor):
-    if obj.requires_grad: parameters.append(obj)
+    parameters.append(obj)
   elif isinstance(obj, list) or isinstance(obj, tuple):
     for x in obj:
       parameters.extend(get_parameters(x))
diff --git a/test/test_gc.py b/test/test_gc.py
index da76aeaaab..f857e0d107 100644
--- a/test/test_gc.py
+++ b/test/test_gc.py
@@ -9,22 +9,22 @@ def tensors_allocated():
 
 class TestGC(unittest.TestCase):
   def test_gc(self):
-    a = Tensor.zeros(4,4)
-    b = Tensor.zeros(4,4)
+    a = Tensor.zeros(4, 4, requires_grad=True)
+    b = Tensor.zeros(4, 4, requires_grad=True)
     (a*b).mean().backward()
     assert(tensors_allocated() > 0)
     del a,b
     assert(tensors_allocated() == 0)
 
   def test_gc_complex(self):
-    a = Tensor.zeros(4,4)
-    b = Tensor.zeros(4,4)
+    a = Tensor.zeros(4, 4, requires_grad=True)
+    b = Tensor.zeros(4, 4, requires_grad=True)
     assert(tensors_allocated() == 2)
     (a*b).mean().backward()
     assert(tensors_allocated() == 4)
     del b
     assert(tensors_allocated() == 2)
-    b = Tensor.zeros(4,4)
+    b = Tensor.zeros(4, 4, requires_grad=True)
     print(tensors_allocated())
     (a*b).mean().backward()
     print(tensors_allocated())
diff --git a/test/test_ops.py b/test/test_ops.py
index 9eb66facc9..722d6c29ac 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -14,7 +14,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_ato
   else:
     ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps]
 
-  tst = [Tensor(x.detach().numpy()) for x in ts]
+  tst = [Tensor(x.detach().numpy(), requires_grad=True) for x in ts]
 
   out = torch_fxn(*ts)
   ret = tinygrad_fxn(*tst)
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 1914d64caa..65d579bb30 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -22,9 +22,9 @@ class TestTinygrad(unittest.TestCase):
 
   def test_backward_pass(self):
     def test_tinygrad():
-      x = Tensor(x_init)
-      W = Tensor(W_init)
-      m = Tensor(m_init, requires_grad=False)
+      x = Tensor(x_init, requires_grad=True)
+      W = Tensor(W_init, requires_grad=True)
+      m = Tensor(m_init)
       out = x.dot(W).relu()
       out = out.logsoftmax()
       out = out.mul(m).add(m).sum()
@@ -46,9 +46,9 @@ class TestTinygrad(unittest.TestCase):
 
   def test_backward_pass_diamond_model(self):
     def test_tinygrad():
-      u = Tensor(U_init)
-      v = Tensor(V_init)
-      w = Tensor(W_init)
+      u = Tensor(U_init, requires_grad=True)
+      v = Tensor(V_init, requires_grad=True)
+      w = Tensor(W_init, requires_grad=True)
       x = u.mul(v).relu()
       y = u.mul(w).relu()
       out = x.add(y).mul(y).relu()
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index 9d99f9041a..a835be2141 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -3,6 +3,11 @@ from tinygrad.tensor import Tensor
 
 class Optimizer:
   def __init__(self, params):
+    # if it's None, but being put into an optimizer, set it to True
+    for x in params:
+      if x.requires_grad is None:
+        x.requires_grad = True
+
     self.params = [x for x in params if x.requires_grad]
 
   def zero_grad(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d6790230b7..cd98b30a96 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -14,7 +14,7 @@ class Tensor:
   # TODO: remove no_init when uniform is late bind
   training, no_grad, no_init = False, False, False
 
-  def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
+  def __init__(self, data, device=Device.DEFAULT, requires_grad=None):
     if isinstance(data, list):
       data = np.array(data, dtype=np.float32)
     elif isinstance(data, LazyBuffer) and data.device != device:
@@ -32,7 +32,10 @@ class Tensor:
 
     # tensors have gradients, buffers do not
     self.grad : Optional[Tensor] = None
-    self.requires_grad = requires_grad
+
+    # NOTE: this can be in three states. False and None: no gradient, True: gradient
+    # None (the default) will be updated to True if it's put in an optimizer
+    self.requires_grad : Optional[bool] = requires_grad
 
     # internal variables used for autograd graph construction
     self._ctx : Optional[Function] = None
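
For reference, a minimal sketch of how the new three-state `requires_grad` default plays out in user code after this patch. The `Tensor` and `Optimizer` behavior is taken from the diff above; the `SGD` import and its `(params, lr=...)` signature are assumed rather than part of this change.

```python
# Sketch only: SGD and its (params, lr=...) signature are assumed;
# the requires_grad semantics follow the diff above.
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD

w = Tensor.eye(3)                                 # new default: requires_grad is None
x = Tensor([[2.0, 0, -2.0]], requires_grad=True)  # explicitly requested gradient

opt = SGD([w], lr=0.01)        # Optimizer.__init__ flips None -> True
print(w.requires_grad)         # True

x.matmul(w).sum().backward()
print(w.grad is not None, x.grad is not None)     # True True

frozen = Tensor.eye(3, requires_grad=False)       # explicit False is never flipped
print(len(SGD([frozen], lr=0.01).params))         # 0: filtered out of the optimizer
```

Note the ordering: the optimizer is constructed before the forward pass, so `w` already requires a gradient when the graph is built; a tensor left at the default `None` and never handed to an optimizer behaves like a buffer and gets no gradient.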