From 745e36fda5d9598638ca167734918aa2fcba4827 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 3 Jul 2022 12:41:05 -0700
Subject: [PATCH] mlops cleanup

---
 tinygrad/mlops.py  | 74 +++++++++++++++++++---------------
 tinygrad/tensor.py |  5 ++--
 2 files changed, 32 insertions(+), 47 deletions(-)

diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index fe786557fe..dfbd04aea0 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -5,38 +5,32 @@ from tinygrad.tensor import Function
 
 # ************* unary ops *************
 
-class _UnaryOp(Function):
+class ReLU(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return input.unary_op(ctx.fop)
+    return input.unary_op(UnaryOps.RELU)
 
   def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return input.binary_op(ctx.bop, grad_output)
-
-class ReLU(_UnaryOp):
-  fop = UnaryOps.RELU
-
-  def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    ret = input.unary_op(UnaryOps.SIGN)
+    ret = ctx.saved_tensors[0].unary_op(UnaryOps.SIGN)
     ret = ret.unary_op(UnaryOps.RELU)
     return ret.binary_op(BinaryOps.MUL, grad_output)
 
-class Log(_UnaryOp):
-  fop = UnaryOps.LOG
+class Log(Function):
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return input.unary_op(UnaryOps.LOG)
 
   def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return grad_output.binary_op(BinaryOps.DIV, input)
+    return grad_output.binary_op(BinaryOps.DIV, ctx.saved_tensors[0])
 
-class Exp(_UnaryOp):
+class Exp(Function):
   def forward(ctx, input):
     ret = input.unary_op(UnaryOps.EXP)
     ctx.save_for_backward(ret)   # we save the output here, not the input
     return ret
 
-  bop = BinaryOps.MUL
+  def backward(ctx, grad_output):
+    return ctx.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)
 
 # TODO: add Neg? confirm the optimizer on Sub good enough
 
@@ -44,12 +38,11 @@ class Exp(_UnaryOp):
 
 class Sum(Function):
   def forward(ctx, input, axis=None):
-    ctx.save_for_backward(input.shape)
+    ctx.input_shape = input.shape
     return input.reduce_op(ReduceOps.SUM, reduce_shape(input.shape, axis))
 
   def backward(ctx, grad_output):
-    shape_input, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.EXPAND, shape_input)
+    return grad_output.movement_op(MovementOps.EXPAND, ctx.input_shape)
 
 class Max(Function):
   def forward(ctx, input, axis=None):
@@ -95,9 +88,8 @@ class Mul(Function):
     return x.binary_op(BinaryOps.MUL, y)
 
   def backward(ctx, grad_output):
-    x,y = ctx.saved_tensors
-    grad_x = y.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
-    grad_y = x.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
+    grad_x = ctx.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
+    grad_y = ctx.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 
 # TODO: add Div? is the optimizer on Pow good enough?
@@ -125,53 +117,46 @@ class Pow(Function):
 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
+    ctx.input_shape = x.shape
     return x.movement_op(MovementOps.EXPAND, shape)
 
   def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.reduce_op(ReduceOps.SUM, in_shape)
+    return grad_output.reduce_op(ReduceOps.SUM, ctx.input_shape)
 
 class Reshape(Function):
   def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
+    ctx.input_shape = x.shape
     shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
     return x.movement_op(MovementOps.RESHAPE, shape)
 
   def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.RESHAPE, in_shape)
+    return grad_output.movement_op(MovementOps.RESHAPE, ctx.input_shape)
 
 class Permute(Function):
   def forward(ctx, x, order=(1,0)):
-    ctx.save_for_backward(order)
+    ctx.input_order = order
     return x.movement_op(MovementOps.PERMUTE, order)
 
   def backward(ctx, grad_output):
-    order, = ctx.saved_tensors
-    norder = np.argsort(order).tolist()
+    norder = np.argsort(ctx.input_order).tolist()
     return grad_output.movement_op(MovementOps.PERMUTE, norder)
 
 # TODO: merge Slice and Flip into Stride with the 3 arguments
-
 class Slice(Function):
   def forward(ctx, x, arg=None):
-    ctx.save_for_backward(x.shape, arg)
+    ctx.narg = [(0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg)]
     return x.movement_op(MovementOps.SLICE, arg)
 
   def backward(ctx, grad_output):
-    shape, arg = ctx.saved_tensors
-    narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
-    return grad_output.movement_op(MovementOps.SLICE, narg)
+    return grad_output.movement_op(MovementOps.SLICE, ctx.narg)
 
 class Flip(Function):
   def forward(ctx, x, axis):
-    ctx.save_for_backward(axis)
+    ctx.axis = axis
     return x.movement_op(MovementOps.FLIP, axis)
 
   def backward(ctx, grad_output):
-    axis, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.FLIP, axis)
+    return grad_output.movement_op(MovementOps.FLIP, ctx.axis)
 
 # ************* processing ops *************
 
@@ -182,12 +167,13 @@ class Conv2D(Function):
     return x.processing_op(ProcessingOps.CONV, w, C)
 
   def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
-    C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
-    ctx.save_for_backward(x,w,C)
-    return ctx._conv(x, w, C)
+    ctx.C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
+    ctx.save_for_backward(x,w)
+    return ctx._conv(x, w, ctx.C)
 
   def backward(ctx, grad_output):
-    x, w, C = ctx.saved_tensors
+    x, w = ctx.saved_tensors
+    C = ctx.C   # conv args from the context
     dx, dw = None, None
     if ctx.needs_input_grad[0]:   # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
       xt = grad_output
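
The mlops.py hunks above all apply the same cleanup: save_for_backward / saved_tensors is kept for actual tensors, while plain Python data such as shapes, axes, permutation orders, and conv args is stored directly as attributes on ctx. Below is a minimal sketch of that pattern outside tinygrad; ToyContext and ToySum are hypothetical stand-ins for illustration, not tinygrad's Function machinery.

# A minimal sketch of the pattern used above (toy classes, not tinygrad code):
# tensors go through save_for_backward, plain metadata is just set on the context.
class ToyContext:
  def __init__(self):
    self.to_save = ()
  def save_for_backward(self, *tensors):
    self.to_save = tensors          # reserved for values the backward pass treats as tensors
  @property
  def saved_tensors(self):
    return self.to_save

class ToySum:
  @staticmethod
  def forward(ctx, data):
    ctx.input_shape = (len(data),)  # plain metadata: stored as a context attribute
    return sum(data)
  @staticmethod
  def backward(ctx, grad_output):
    # "expand" the upstream gradient back to the saved input shape
    return [grad_output] * ctx.input_shape[0]

ctx = ToyContext()
assert ToySum.forward(ctx, [1.0, 2.0, 3.0]) == 6.0
assert ToySum.backward(ctx, 1.0) == [1.0, 1.0, 1.0]

Conv2D above mixes both sides of the pattern: x and w still go through save_for_backward, while the conv args land on ctx.C.
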
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index bbe9aa6240..01d7f11851 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -183,6 +183,7 @@ class Tensor:
     return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
 
   def matmul(x, w):
+    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2])
     cin, cout = w.shape[-2], w.shape[-1]
     out_shape_t = tuple(list(x.shape[0:-2])+[cout,-1])
@@ -321,9 +322,7 @@ class Tensor:
     ret = self.mul(weight.reshape(shape=shp)) if len(weight.shape) == 1 else self.dot(weight)
     return ret.add(bias.reshape(shape=shp))
 
-  def sequential(self, ll):
-    for l in ll: self = l(self)
-    return self
+  def sequential(self, ll): return functools.reduce(lambda x,f: f(x), ll, self)
 
   def layernorm(x, eps=1e-5):
     y = (x - x.mean(axis=-1, keepdim=True))
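
The tensor.py change at the end collapses sequential into a functools.reduce over the layer list, which threads the value through each callable left to right exactly as the old loop did. A small sketch of the equivalence, using plain int-valued lambdas as stand-in layers (illustrative only, not tinygrad layers):

# Sketch of the sequential() rewrite: reduce folds the input through the layer list
# in order, matching the old explicit for-loop.
import functools

layers = [lambda x: x + 1, lambda x: x * 3, lambda x: x - 2]

def sequential_loop(x, ll):
  for l in ll: x = l(x)            # old style: reassign as each layer is applied
  return x

def sequential_reduce(x, ll):
  return functools.reduce(lambda acc, f: f(acc), ll, x)   # new one-liner style

assert sequential_loop(5, layers) == sequential_reduce(5, layers) == 16   # ((5 + 1) * 3) - 2

Since the hunk calls functools.reduce without adding an import, functools is presumably already imported at the top of tensor.py.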