support requires_grad

This commit is contained in:
George Hotz
2022-06-06 07:47:31 -07:00
parent 9f9cf076c0
commit 233c71a7ba
3 changed files with 9 additions and 7 deletions

View File

@@ -20,7 +20,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
   losses, accuracies = [], []
   for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
     samp = np.random.randint(0, X_train.shape[0], size=(BS))
-    x = Tensor(transform(X_train[samp]))
+    x = Tensor(transform(X_train[samp]), requires_grad=False)
     y = target_transform(Y_train[samp])
     # network

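Note on the hunk above: the training batch is pure input data, never a parameter, so constructing it with requires_grad=False lets every op that consumes it skip the corresponding gradient kernel. A minimal sketch of the effect, assuming the Tensor constructor behavior shown in this commit (the shapes and the .dot()/.mean() calls are illustrative assumptions, not taken from the diff):

import numpy as np
from tinygrad.tensor import Tensor

# input batch: plain data, so no gradient should ever be computed for it
x = Tensor(np.random.randn(128, 784).astype(np.float32), requires_grad=False)
# weight: requires_grad left at its default, assumed True here
w = Tensor(np.random.randn(784, 10).astype(np.float32))

out = x.dot(w).mean()
out.backward()
# w.grad is populated; the grad_input kernel for x is never launched,
# because Matmul.backward sees ctx.needs_input_grad[0] == False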
View File

@@ -158,8 +158,8 @@ class Matmul(Function):
   def backward(ctx, grad_output):
     input, weight = ctx.saved_tensors
-    grad_input = matmul(grad_output, weight, buffer_new(input.shape), transpose_b=True)
-    grad_weight = matmul(input, grad_output, buffer_new(weight.shape), transpose_a=True)
+    grad_input = matmul(grad_output, weight, buffer_new(input.shape), transpose_b=True) if ctx.needs_input_grad[0] else None
+    grad_weight = matmul(input, grad_output, buffer_new(weight.shape), transpose_a=True) if ctx.needs_input_grad[1] else None
     return grad_input, grad_weight

 class Conv2D(Function):
@@ -191,6 +191,6 @@ class Conv2D(Function):
     rcout = cout//ctx.groups
     conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
-    dw = convdw(x, grad_output, buffer_new((cout, cin, H, W)), conv_args)
-    dx = convdx(w, grad_output, buffer_new((bs, cin_, iy, ix), zero=True), conv_args)
+    dx = convdx(w, grad_output, buffer_new((bs, cin_, iy, ix), zero=True), conv_args) if ctx.needs_input_grad[0] else None
+    dw = convdw(x, grad_output, buffer_new((cout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
     return dx, dw

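Both GPU ops above now follow the same pattern: consult ctx.needs_input_grad (recorded by Function.apply, see the hunk below) before launching an expensive gradient kernel, and return None for any input that does not need one. A minimal sketch of the pattern for a generic two-input op, with numpy arithmetic standing in for the GPU kernels (the Mul class itself is illustrative, not part of this commit):

import numpy as np
from tinygrad.tensor import Function

class Mul(Function):
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
    return x * y  # numpy stand-in for the real kernel launch

  def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    # guard each gradient with needs_input_grad: if the matching input
    # was created with requires_grad=False, skip the computation entirely
    grad_x = grad_output * y if ctx.needs_input_grad[0] else None
    grad_y = grad_output * x if ctx.needs_input_grad[1] else None
    return grad_x, grad_y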
View File

@@ -161,8 +161,9 @@ class Tensor:
       assert (t0.grad is not None)
       with ProfileOp(t0._ctx, t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
         grads = t0._ctx.backward(t0._ctx, t0.grad.data)
-        po.output = grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
+        grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
           for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
+        po.output = [x for x in grads if x is not None] # backward can return None if no required gradient, don't profile it
       for t, g in zip(t0._ctx.parents, grads):
         if g is not None and t.requires_grad:
           assert g.shape == t.shape, \
@@ -382,9 +383,10 @@ class Function:
     # overwrite with passed params
     for k, v in kwargs.items():
       setattr(ctx, k, v)
+    ctx.needs_input_grad = [t.requires_grad for t in x]
     with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
       ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
-                   device=ctx.device, requires_grad=any(t.requires_grad for t in x))
+                   device=ctx.device, requires_grad=any(ctx.needs_input_grad))
       po.output = [ret]
     if ret.requires_grad:
       ret._ctx = ctx
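Taken together, the flow is: apply() snapshots each input's requires_grad into ctx.needs_input_grad, the output's requires_grad becomes the OR of that list, and backward() consults the same list to return None wherever no gradient is wanted; Tensor.backward then skips those None entries in the zip over parents. A short usage sketch of the resulting behavior (the .dot()/.mean()/.backward() calls and the default requires_grad=True are assumptions about the era's Tensor API, not shown in this diff):

import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.random.randn(4, 3).astype(np.float32), requires_grad=False)
w = Tensor(np.random.randn(3, 2).astype(np.float32))  # assumed default: requires_grad=True

out = x.dot(w)            # apply() records ctx.needs_input_grad == [False, True]
assert out.requires_grad  # any(ctx.needs_input_grad) is True

out.mean().backward()
assert w.grad is not None  # grad_weight was computed
assert x.grad is None      # grad_input came back as None and was skipped in the zip()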