mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-23 05:48:08 -05:00
mlops cleanup
tinygrad/mlops.py

@@ -5,38 +5,32 @@ from tinygrad.tensor import Function
 
 # ************* unary ops *************
 
-class _UnaryOp(Function):
+class ReLU(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return input.unary_op(ctx.fop)
+    return input.unary_op(UnaryOps.RELU)
 
   def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return input.binary_op(ctx.bop, grad_output)
-
-class ReLU(_UnaryOp):
-  fop = UnaryOps.RELU
-
-  def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    ret = input.unary_op(UnaryOps.SIGN)
+    ret = ctx.saved_tensors[0].unary_op(UnaryOps.SIGN)
     ret = ret.unary_op(UnaryOps.RELU)
     return ret.binary_op(BinaryOps.MUL, grad_output)
 
-class Log(_UnaryOp):
-  fop = UnaryOps.LOG
+class Log(Function):
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return input.unary_op(UnaryOps.LOG)
 
   def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return grad_output.binary_op(BinaryOps.DIV, input)
+    return grad_output.binary_op(BinaryOps.DIV, ctx.saved_tensors[0])
 
-class Exp(_UnaryOp):
+class Exp(Function):
   def forward(ctx, input):
     ret = input.unary_op(UnaryOps.EXP)
     ctx.save_for_backward(ret) # we save the output here, not the input
     return ret
 
-  bop = BinaryOps.MUL
+  def backward(ctx, grad_output):
+    return ctx.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)
 
 # TODO: add Neg? confirm the optimizer on Sub good enough
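This hunk folds the _UnaryOp base class (which dispatched through the fop/bop class attributes) into explicit ReLU, Log, and Exp Functions whose backward methods index ctx.saved_tensors directly. The gradient formulas are unchanged: ReLU scales the incoming gradient by relu(sign(x)), Log divides it by the input, and Exp multiplies it by the saved forward output, since d/dx e^x = e^x. A minimal numpy sketch (not tinygrad code; numerical_grad is a hypothetical helper) that checks those three formulas against finite differences:

import numpy as np

def numerical_grad(f, x, eps=1e-6):
  # central finite differences of sum(f(x)), elementwise
  g = np.zeros_like(x)
  for i in np.ndindex(x.shape):
    d = np.zeros_like(x); d[i] = eps
    g[i] = (f(x + d).sum() - f(x - d).sum()) / (2 * eps)
  return g

x = np.array([-1.5, -0.1, 0.3, 0.7, 2.0])
xp = np.abs(x)   # strictly positive input for log

# ReLU backward: grad * relu(sign(x))  -> 1 where x > 0, else 0
assert np.allclose(numerical_grad(lambda v: np.maximum(v, 0), x), np.maximum(np.sign(x), 0))
# Log backward: grad / x
assert np.allclose(numerical_grad(np.log, xp), 1.0 / xp, atol=1e-5)
# Exp backward: grad * exp(x), i.e. grad times the saved forward output
assert np.allclose(numerical_grad(np.exp, x), np.exp(x), atol=1e-5)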
@@ -44,12 +38,11 @@ class Exp(_UnaryOp):
 
 class Sum(Function):
   def forward(ctx, input, axis=None):
-    ctx.save_for_backward(input.shape)
+    ctx.input_shape = input.shape
     return input.reduce_op(ReduceOps.SUM, reduce_shape(input.shape, axis))
 
   def backward(ctx, grad_output):
-    shape_input, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.EXPAND, shape_input)
+    return grad_output.movement_op(MovementOps.EXPAND, ctx.input_shape)
 
 class Max(Function):
   def forward(ctx, input, axis=None):
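Sum now stashes the input shape as a plain attribute (ctx.input_shape) instead of routing a non-tensor through save_for_backward; the backward pass is still just an EXPAND of the gradient back to that shape, because every input element contributes to the sum with weight 1. A small numpy sketch of that sum/expand duality (illustrative only, not the tinygrad kernels; the reduced axis is kept so broadcasting lines up):

import numpy as np

x = np.random.randn(3, 4)
y = x.sum(axis=1, keepdims=True)                      # forward: reduce to shape (3, 1)
grad_output = np.random.randn(*y.shape)
grad_input = np.broadcast_to(grad_output, x.shape)    # backward: expand back to (3, 4)
# d(sum of row i)/dx[i, j] == 1, so each row's gradient is just grad_output[i, 0]
assert np.allclose(grad_input[1, 2], grad_output[1, 0])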
@@ -95,9 +88,8 @@ class Mul(Function):
     return x.binary_op(BinaryOps.MUL, y)
 
   def backward(ctx, grad_output):
-    x,y = ctx.saved_tensors
-    grad_x = y.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
-    grad_y = x.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
+    grad_x = ctx.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
+    grad_y = ctx.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 
 # TODO: add Div? is the optimizer on Pow good enough?
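Mul.backward keeps the product rule (dL/dx = y * g, dL/dy = x * g) but now reads ctx.saved_tensors in place, and needs_input_grad still skips whichever gradient the graph does not ask for. A plain-Python sketch of that gating pattern (hypothetical names, not the tinygrad Function API):

import numpy as np

def mul_backward(x, y, grad_output, needs_input_grad=(True, True)):
  # product rule; None means "this gradient was not requested"
  grad_x = y * grad_output if needs_input_grad[0] else None
  grad_y = x * grad_output if needs_input_grad[1] else None
  return grad_x, grad_y

x, y, g = np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([0.5, 0.5])
gx, gy = mul_backward(x, y, g, needs_input_grad=(True, False))
assert np.allclose(gx, y * g) and gy is None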
@@ -125,53 +117,46 @@ class Pow(Function):
 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
+    ctx.input_shape = x.shape
     return x.movement_op(MovementOps.EXPAND, shape)
 
   def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.reduce_op(ReduceOps.SUM, in_shape)
+    return grad_output.reduce_op(ReduceOps.SUM, ctx.input_shape)
 
 class Reshape(Function):
   def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
+    ctx.input_shape = x.shape
     shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
     return x.movement_op(MovementOps.RESHAPE, shape)
 
   def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.RESHAPE, in_shape)
+    return grad_output.movement_op(MovementOps.RESHAPE, ctx.input_shape)
 
 class Permute(Function):
   def forward(ctx, x, order=(1,0)):
-    ctx.save_for_backward(order)
+    ctx.input_order = order
     return x.movement_op(MovementOps.PERMUTE, order)
 
   def backward(ctx, grad_output):
-    order, = ctx.saved_tensors
-    norder = np.argsort(order).tolist()
+    norder = np.argsort(ctx.input_order).tolist()
     return grad_output.movement_op(MovementOps.PERMUTE, norder)
 
 # TODO: merge Slice and Flip into Stride with the 3 arguments
 class Slice(Function):
   def forward(ctx, x, arg=None):
-    ctx.save_for_backward(x.shape, arg)
+    ctx.narg = [(0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg)]
     return x.movement_op(MovementOps.SLICE, arg)
 
   def backward(ctx, grad_output):
-    shape, arg = ctx.saved_tensors
-    narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
-    return grad_output.movement_op(MovementOps.SLICE, narg)
+    return grad_output.movement_op(MovementOps.SLICE, ctx.narg)
 
 class Flip(Function):
   def forward(ctx, x, axis):
-    ctx.save_for_backward(axis)
+    ctx.axis = axis
     return x.movement_op(MovementOps.FLIP, axis)
 
   def backward(ctx, grad_output):
-    axis, = ctx.saved_tensors
-    return grad_output.movement_op(MovementOps.FLIP, axis)
+    return grad_output.movement_op(MovementOps.FLIP, ctx.axis)
 
 # ************* processing ops *************
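The movement ops get the same treatment: shapes, permutation orders and slice arguments become plain ctx attributes. Two details are worth spelling out. Permute inverts a permutation with np.argsort(order), and Slice can precompute the backward argument already in forward because grad_output.shape[i] == p[1] - p[0], so the old expression grad_output.shape[i] + (shape[i] - p[1]) collapses to x.shape[i] - p[0]. A short sketch checking both identities (illustrative, not the tinygrad kernels):

import numpy as np

# a permutation is undone by the permutation given by its argsort
order = (2, 0, 1)
inv = tuple(np.argsort(order))
x = np.random.randn(3, 4, 5)
assert np.allclose(x.transpose(order).transpose(inv), x)

# old vs new computation of the backward slice argument
shape = (10, 8)                         # x.shape, saved by the old code
arg = [(2, 7), (0, 8)]                  # per-dim (start, end) of the forward slice
grad_shape = [e - s for s, e in arg]    # grad_output has the sliced shape
old = [(0 - p[0], grad_shape[i] + (shape[i] - p[1])) for i, p in enumerate(arg)]
new = [(0 - p[0], shape[i] - p[0]) for i, p in enumerate(arg)]
assert old == new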
@@ -182,12 +167,13 @@ class Conv2D(Function):
     return x.processing_op(ProcessingOps.CONV, w, C)
 
   def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
-    C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
-    ctx.save_for_backward(x,w,C)
-    return ctx._conv(x, w, C)
+    ctx.C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
+    ctx.save_for_backward(x,w)
+    return ctx._conv(x, w, ctx.C)
 
   def backward(ctx, grad_output):
-    x, w, C = ctx.saved_tensors
+    x, w = ctx.saved_tensors
+    C = ctx.C # conv args from the context
     dx, dw = None, None
     if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
       xt = grad_output
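In Conv2D only real tensors (x, w) stay in save_for_backward; the conv argument bundle from get_conv_args moves to ctx.C. The backward comment states the key fact: the gradient of a convolution with respect to its input is itself a (transposed) convolution of grad_output with the kernel. A 1D numpy sketch of that identity, checked against finite differences (illustrative only; the real code routes this through ProcessingOps.CONV with the ctx.C metadata):

import numpy as np

np.random.seed(0)
x, w = np.random.randn(8), np.random.randn(3)
g = np.random.randn(len(x) - len(w) + 1)       # grad_output of a "valid" conv

def loss(x):
  # forward: 1D cross-correlation (what conv layers compute), weighted by g
  return np.sum(g * np.correlate(x, w, mode='valid'))

# analytic input gradient: a "full" convolution of grad_output with the kernel,
# i.e. the transposed convolution
dx = np.convolve(g, w, mode='full')

eps, num = 1e-6, np.zeros_like(x)
for j in range(len(x)):
  d = np.zeros_like(x); d[j] = eps
  num[j] = (loss(x + d) - loss(x - d)) / (2 * eps)
assert np.allclose(dx, num, atol=1e-5)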
tinygrad/tensor.py

@@ -183,6 +183,7 @@ class Tensor:
     return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
 
   def matmul(x, w):
+    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2])
     cin, cout = w.shape[-2], w.shape[-1]
     out_shape_t = tuple(list(x.shape[0:-2])+[cout,-1])
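The new NOTE in matmul records the trick this Tensor method relies on: an mxk by kxn matrix product is a 1x1 convolution once the left matrix is laid out as a (1, k, m, 1) image and the right one as (n, k, 1, 1) filters, so matmul can reuse the conv2d path. A small numpy check of that layout, with einsum standing in for the 1x1 conv (not the tinygrad implementation, which also handles batch and groups):

import numpy as np

m, k, n = 4, 5, 3
A, B = np.random.randn(m, k), np.random.randn(k, n)

img = A.T.reshape(1, k, m, 1)    # img[0, c, i, 0] = A[i, c]
filt = B.T.reshape(n, k, 1, 1)   # filt[o, c, 0, 0] = B[c, o]

# a 1x1 conv is per-pixel channel mixing: out[b,o,i,j] = sum_c img[b,c,i,j] * filt[o,c,0,0]
out = np.einsum('bcij,ocxy->boij', img, filt)

# the (1, n, m, 1) result, read back as (m, n), is exactly A @ B
assert np.allclose(out[0, :, :, 0].T, A @ B)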
@@ -321,9 +322,7 @@ class Tensor:
     ret = self.mul(weight.reshape(shape=shp)) if len(weight.shape) == 1 else self.dot(weight)
     return ret.add(bias.reshape(shape=shp))
 
-  def sequential(self, ll):
-    for l in ll: self = l(self)
-    return self
+  def sequential(self, ll): return functools.reduce(lambda x,f: f(x), ll, self)
 
   def layernorm(x, eps=1e-5):
     y = (x - x.mean(axis=-1, keepdim=True))
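sequential collapses to a one-liner: functools.reduce threads the tensor through the list of callables left to right, exactly like the old loop. A minimal standalone check of that equivalence (plain functions stand in for layers):

import functools

layers = [lambda x: x + 1, lambda x: x * 2, lambda x: x - 3]

def sequential_loop(x, ll):
  for l in ll: x = l(x)
  return x

def sequential_reduce(x, ll):
  return functools.reduce(lambda acc, f: f(acc), ll, x)

assert sequential_loop(5, layers) == sequential_reduce(5, layers) == (5 + 1) * 2 - 3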