diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index f35cb8cc56..aa077a042e 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -36,7 +36,7 @@ class Exp(_UnaryOp): bop = BinaryOps.MUL -# TODO: add Neg? +# TODO: add Neg? confirm the optimizer on Sub good enough # ************* reduce ops ************* @@ -98,7 +98,7 @@ class Mul(Function): grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None return grad_x, grad_y -# TODO: add Div? +# TODO: add Div? is the optimizer on Pow good enough? class Pow(Function): def forward(ctx, x, y): @@ -130,15 +130,6 @@ class Expand(Function): in_shape, = ctx.saved_tensors return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape) -class Flip(Function): - def forward(ctx, x, axis): - ctx.save_for_backward(axis) - return ctx.movement_op(MovementOps.FLIP, x, axis) - - def backward(ctx, grad_output): - axis, = ctx.saved_tensors - return ctx.movement_op(MovementOps.FLIP, grad_output, axis) - class Reshape(Function): def forward(ctx, x, shape): ctx.save_for_backward(x.shape) @@ -159,6 +150,8 @@ class Permute(Function): norder = np.argsort(order).tolist() return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder) +# TODO: merge Slice and Flip into Stride with the 3 arguments + class Slice(Function): def forward(ctx, x, arg=None): ctx.save_for_backward(x.shape, arg) @@ -169,6 +162,15 @@ class Slice(Function): narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)] return ctx.movement_op(MovementOps.SLICE, grad_output, narg) +class Flip(Function): + def forward(ctx, x, axis): + ctx.save_for_backward(axis) + return ctx.movement_op(MovementOps.FLIP, x, axis) + + def backward(ctx, grad_output): + axis, = ctx.saved_tensors + return ctx.movement_op(MovementOps.FLIP, grad_output, axis) + # ************* processing ops ************* class Conv2D(Function):