mirror of https://github.com/tinygrad/tinygrad.git
support requires_grad
@@ -20,7 +20,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
   losses, accuracies = [], []
   for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
     samp = np.random.randint(0, X_train.shape[0], size=(BS))
-    x = Tensor(transform(X_train[samp]))
+    x = Tensor(transform(X_train[samp]), requires_grad=False)
     y = target_transform(Y_train[samp])

     # network
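The training loop now marks the input batch with requires_grad=False so backward never builds or stores a gradient for raw data. A minimal sketch of the effect, assuming this era's Tensor API (dot/matmul, mean, backward) and float32 numpy inputs:

from tinygrad.tensor import Tensor
import numpy as np

x = Tensor(np.random.randn(4, 784).astype(np.float32), requires_grad=False)  # data batch
w = Tensor(np.random.randn(784, 10).astype(np.float32), requires_grad=True)  # weights

out = x.dot(w).mean()  # dot may be registered as matmul depending on the version
out.backward()

assert w.grad is not None  # weights still receive a gradient
assert x.grad is None      # nothing is stored for the data tensor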
@@ -158,8 +158,8 @@ class Matmul(Function):

   def backward(ctx, grad_output):
     input, weight = ctx.saved_tensors
-    grad_input = matmul(grad_output, weight, buffer_new(input.shape), transpose_b=True)
-    grad_weight = matmul(input, grad_output, buffer_new(weight.shape), transpose_a=True)
+    grad_input = matmul(grad_output, weight, buffer_new(input.shape), transpose_b=True) if ctx.needs_input_grad[0] else None
+    grad_weight = matmul(input, grad_output, buffer_new(weight.shape), transpose_a=True) if ctx.needs_input_grad[1] else None
     return grad_input, grad_weight

 class Conv2D(Function):
@@ -191,6 +191,6 @@ class Conv2D(Function):
     rcout = cout//ctx.groups

     conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
-    dw = convdw(x, grad_output, buffer_new((cout, cin, H, W)), conv_args)
-    dx = convdx(w, grad_output, buffer_new((bs, cin_, iy, ix), zero=True), conv_args)
+    dx = convdx(w, grad_output, buffer_new((bs, cin_, iy, ix), zero=True), conv_args) if ctx.needs_input_grad[0] else None
+    dw = convdw(x, grad_output, buffer_new((cout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
     return dx, dw
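Both backward passes above now consult ctx.needs_input_grad (filled in by Function.apply below) and skip any gradient no input asked for. A hedged sketch of the same gating on a simple elementwise op; Mul here is illustrative and assumes Function, save_for_backward and saved_tensors behave as in tinygrad.tensor, with forward/backward operating on numpy arrays:

from tinygrad.tensor import Function

class Mul(Function):
  @staticmethod
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
    return x * y

  @staticmethod
  def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    # only pay for gradients that some input actually requires
    grad_x = grad_output * y if ctx.needs_input_grad[0] else None
    grad_y = grad_output * x if ctx.needs_input_grad[1] else None
    return grad_x, grad_y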
@@ -161,8 +161,9 @@ class Tensor:
       assert (t0.grad is not None)
       with ProfileOp(t0._ctx, t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
         grads = t0._ctx.backward(t0._ctx, t0.grad.data)
-        po.output = grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
+        grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
           for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
+        po.output = [x for x in grads if x is not None]  # backward can return None if no required gradient, don't profile it
       for t, g in zip(t0._ctx.parents, grads):
         if g is not None and t.requires_grad:
           assert g.shape == t.shape, \
@@ -382,9 +383,10 @@ class Function:
     # overwrite with passed params
     for k, v in kwargs.items():
       setattr(ctx, k, v)
+    ctx.needs_input_grad = [t.requires_grad for t in x]
     with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
       ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
-                   device=ctx.device, requires_grad=any(t.requires_grad for t in x))
+                   device=ctx.device, requires_grad=any(ctx.needs_input_grad))
       po.output = [ret]
     if ret.requires_grad:
       ret._ctx = ctx
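With needs_input_grad recorded on the ctx, an op whose inputs all have requires_grad=False yields an output that neither requires grad nor keeps a _ctx, so no backward graph is retained for it. A small sketch of that behaviour, assuming mul is registered on Tensor as in this version:

from tinygrad.tensor import Tensor
import numpy as np

a = Tensor(np.ones((2, 2), dtype=np.float32), requires_grad=False)
b = Tensor(np.ones((2, 2), dtype=np.float32), requires_grad=False)

c = a.mul(b)
assert c.requires_grad is False  # any(ctx.needs_input_grad) was False
assert c._ctx is None            # so the graph node is not kept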