This commit is contained in:
George Hotz
2022-06-22 09:37:50 -07:00
parent 1d4fb3527e
commit 6847eaf5b6

View File

@@ -36,7 +36,7 @@ class Exp(_UnaryOp):
bop = BinaryOps.MUL
# TODO: add Neg?
# TODO: add Neg? confirm the optimizer on Sub good enough
# ************* reduce ops *************
@@ -98,7 +98,7 @@ class Mul(Function):
grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
return grad_x, grad_y
# TODO: add Div?
# TODO: add Div? is the optimizer on Pow good enough?
class Pow(Function):
def forward(ctx, x, y):
@@ -130,15 +130,6 @@ class Expand(Function):
in_shape, = ctx.saved_tensors
return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
class Flip(Function):
def forward(ctx, x, axis):
ctx.save_for_backward(axis)
return ctx.movement_op(MovementOps.FLIP, x, axis)
def backward(ctx, grad_output):
axis, = ctx.saved_tensors
return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
class Reshape(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
@@ -159,6 +150,8 @@ class Permute(Function):
norder = np.argsort(order).tolist()
return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder)
# TODO: merge Slice and Flip into Stride with the 3 arguments
class Slice(Function):
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape, arg)
@@ -169,6 +162,15 @@ class Slice(Function):
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
return ctx.movement_op(MovementOps.SLICE, grad_output, narg)
class Flip(Function):
def forward(ctx, x, axis):
ctx.save_for_backward(axis)
return ctx.movement_op(MovementOps.FLIP, x, axis)
def backward(ctx, grad_output):
axis, = ctx.saved_tensors
return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
# ************* processing ops *************
class Conv2D(Function):