mlops cleanup

George Hotz
2022-07-03 12:41:05 -07:00
parent 93c378dffc
commit 745e36fda5
2 changed files with 32 additions and 47 deletions

tinygrad/mlops.py

@@ -5,38 +5,32 @@ from tinygrad.tensor import Function
 # ************* unary ops *************
-class _UnaryOp(Function):
+class ReLU(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return input.unary_op(ctx.fop)
+    return input.unary_op(UnaryOps.RELU)
   def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return input.binary_op(ctx.bop, grad_output)
-class ReLU(_UnaryOp):
-  fop = UnaryOps.RELU
-  def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    ret = input.unary_op(UnaryOps.SIGN)
+    ret = ctx.saved_tensors[0].unary_op(UnaryOps.SIGN)
     ret = ret.unary_op(UnaryOps.RELU)
     return ret.binary_op(BinaryOps.MUL, grad_output)
-class Log(_UnaryOp):
-  fop = UnaryOps.LOG
+class Log(Function):
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return input.unary_op(UnaryOps.LOG)
   def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return grad_output.binary_op(BinaryOps.DIV, input)
+    return grad_output.binary_op(BinaryOps.DIV, ctx.saved_tensors[0])
-class Exp(_UnaryOp):
+class Exp(Function):
   def forward(ctx, input):
     ret = input.unary_op(UnaryOps.EXP)
     ctx.save_for_backward(ret)  # we save the output here, not the input
     return ret
-  bop = BinaryOps.MUL
+  def backward(ctx, grad_output):
+    return ctx.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)
 # TODO: add Neg? confirm the optimizer on Sub good enough
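
Aside on the unary ops above: a minimal NumPy sketch (not tinygrad code) of why ReLU's backward builds its gradient mask as relu(sign(x)), and why Exp can reuse its saved output.

import numpy as np
x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
grad = np.random.randn(5)
mask = np.maximum(np.sign(x), 0)                  # relu(sign(x)): 1 where x > 0, else 0
assert np.allclose(mask * grad, (x > 0) * grad)   # the usual ReLU gradient
# Exp saves its output rather than its input because d(e^x)/dx = e^x,
# so backward is just saved_output * grad_output with no recomputation.
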
@@ -44,12 +38,11 @@ class Exp(_UnaryOp):
 class Sum(Function):
   def forward(ctx, input, axis=None):
-    ctx.save_for_backward(input.shape)
+    ctx.input_shape = input.shape
     return input.reduce_op(ReduceOps.SUM, reduce_shape(input.shape, axis))
   def backward(ctx, grad_output):
-    shape_input, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.EXPAND, shape_input)
+    return grad_output.movement_op(MovementOps.EXPAND, ctx.input_shape)
 class Max(Function):
   def forward(ctx, input, axis=None):
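
Sum's backward above is just an EXPAND of the gradient back to ctx.input_shape; a NumPy sketch (not tinygrad code) of the idea:

import numpy as np
x = np.random.randn(3, 4)
g = np.random.randn(3, 1)             # grad of x.sum(axis=1, keepdims=True)
dx = np.broadcast_to(g, x.shape)      # EXPAND: every summed element gets that grad
assert dx.shape == x.shape and (dx[:, 2] == g[:, 0]).all()
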
@@ -95,9 +88,8 @@ class Mul(Function):
     return x.binary_op(BinaryOps.MUL, y)
   def backward(ctx, grad_output):
-    x,y = ctx.saved_tensors
-    grad_x = y.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
-    grad_y = x.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
+    grad_x = ctx.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
+    grad_y = ctx.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 # TODO: add Div? is the optimizer on Pow good enough?
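
The swapped saved_tensors indices above are the product rule; a quick finite-difference check (plain Python, not tinygrad):

x, y, eps = 3.0, 5.0, 1e-6
fd = ((x + eps) * y - x * y) / eps    # numerical d(x*y)/dx
assert abs(fd - y) < 1e-4             # matches grad_x = y * grad_output
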
@@ -125,53 +117,46 @@ class Pow(Function):
 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
+    ctx.input_shape = x.shape
     return x.movement_op(MovementOps.EXPAND, shape)
   def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.reduce_op(ReduceOps.SUM, in_shape)
+    return grad_output.reduce_op(ReduceOps.SUM, ctx.input_shape)
 class Reshape(Function):
   def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
+    ctx.input_shape = x.shape
     shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
     return x.movement_op(MovementOps.RESHAPE, shape)
   def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.RESHAPE, in_shape)
+    return grad_output.movement_op(MovementOps.RESHAPE, ctx.input_shape)
 class Permute(Function):
   def forward(ctx, x, order=(1,0)):
-    ctx.save_for_backward(order)
+    ctx.input_order = order
     return x.movement_op(MovementOps.PERMUTE, order)
   def backward(ctx, grad_output):
-    order, = ctx.saved_tensors
-    norder = np.argsort(order).tolist()
+    norder = np.argsort(ctx.input_order).tolist()
     return grad_output.movement_op(MovementOps.PERMUTE, norder)
 # TODO: merge Slice and Flip into Stride with the 3 arguments
 class Slice(Function):
   def forward(ctx, x, arg=None):
-    ctx.save_for_backward(x.shape, arg)
+    ctx.narg = [(0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg)]
     return x.movement_op(MovementOps.SLICE, arg)
   def backward(ctx, grad_output):
-    shape, arg = ctx.saved_tensors
-    narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
-    return grad_output.movement_op(MovementOps.SLICE, narg)
+    return grad_output.movement_op(MovementOps.SLICE, ctx.narg)
 class Flip(Function):
   def forward(ctx, x, axis):
-    ctx.save_for_backward(axis)
+    ctx.axis = axis
     return x.movement_op(MovementOps.FLIP, axis)
   def backward(ctx, grad_output):
-    axis, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.FLIP, axis)
+    return grad_output.movement_op(MovementOps.FLIP, ctx.axis)
 # ************* processing ops *************
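
Two of the movement-op backward rules above, sketched in NumPy (not tinygrad code): Permute inverts its order with argsort, and Slice's precomputed ctx.narg pads the gradient back to the input shape (assuming SLICE reads out-of-range indices as zeros).

import numpy as np
# Permute: argsort of a permutation is its inverse permutation.
order = (2, 0, 1)
x = np.random.randn(3, 4, 5)
assert (np.transpose(np.transpose(x, order), np.argsort(order)) == x).all()
# Slice: for x.shape[i] = 10 and arg[i] = (2, 7), ctx.narg[i] = (-2, 8); slicing the
# 5-long gradient with those bounds zero-pads it back out to length 10.
start, stop, dim = 2, 7, 10
g = np.arange(stop - start)
dx = np.pad(g, (start, dim - stop))    # the equivalent zero padding
assert dx.shape[0] == dim
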
@@ -182,12 +167,13 @@ class Conv2D(Function):
     return x.processing_op(ProcessingOps.CONV, w, C)
   def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
-    C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
-    ctx.save_for_backward(x,w,C)
-    return ctx._conv(x, w, C)
+    ctx.C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
+    ctx.save_for_backward(x,w)
+    return ctx._conv(x, w, ctx.C)
   def backward(ctx, grad_output):
-    x, w, C = ctx.saved_tensors
+    x, w = ctx.saved_tensors
+    C = ctx.C  # conv args from the context
     dx, dw = None, None
     if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
       xt = grad_output
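
The comment in Conv2D.backward above says the input gradient is itself a conv (a transposed conv); a 1D NumPy sketch (not tinygrad code) of that fact:

import numpy as np
rng = np.random.default_rng(0)
x, w = rng.normal(size=6), rng.normal(size=3)
fwd = lambda v: np.correlate(v, w, mode='valid')        # "conv" as cross-correlation
g = rng.normal(size=fwd(x).shape)                       # upstream gradient
dx = np.convolve(g, w, mode='full')                     # transposed conv gives d/dx
eps, j = 1e-6, 2
xp = x.copy(); xp[j] += eps
fd = (np.dot(g, fwd(xp)) - np.dot(g, fwd(x))) / eps     # finite-difference check
assert abs(fd - dx[j]) < 1e-4
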

tinygrad/tensor.py

@@ -183,6 +183,7 @@ class Tensor:
     return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
   def matmul(x, w):
+    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2])
     cin, cout = w.shape[-2], w.shape[-1]
     out_shape_t = tuple(list(x.shape[0:-2])+[cout,-1])
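
A NumPy sketch (not tinygrad code) of the NOTE above: an mxk @ kxn matmul really is a 1x1 conv of a (1,k,m,1) image with (n,k,1,1) filters.

import numpy as np
m, k, n = 3, 4, 5
x, w = np.random.randn(m, k), np.random.randn(k, n)
img = x.T.reshape(1, k, m, 1)                             # (batch, cin=k, H=m, W=1)
filt = w.T.reshape(n, k, 1, 1)                            # (cout=n, cin=k, 1, 1)
out = np.einsum('bchw,oc->bohw', img, filt[:, :, 0, 0])   # a 1x1 convolution
assert np.allclose(out[0, :, :, 0].T, x @ w)              # reshape back gives mxn
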
@@ -321,9 +322,7 @@ class Tensor:
     ret = self.mul(weight.reshape(shape=shp)) if len(weight.shape) == 1 else self.dot(weight)
     return ret.add(bias.reshape(shape=shp))
-  def sequential(self, ll):
-    for l in ll: self = l(self)
-    return self
+  def sequential(self, ll): return functools.reduce(lambda x,f: f(x), ll, self)
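
The one-line sequential above folds the input through each layer; a standalone sketch (not tinygrad code) showing it matches the old loop:

import functools
layers = [lambda t: t + 1, lambda t: t * 2]
out = 3
for l in layers: out = l(out)                                   # the old loop
assert functools.reduce(lambda t,f: f(t), layers, 3) == out     # the new one-liner, == 8
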
   def layernorm(x, eps=1e-5):
     y = (x - x.mean(axis=-1, keepdim=True))