mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
comments
This commit is contained in:
@@ -36,7 +36,7 @@ class Exp(_UnaryOp):
|
||||
|
||||
bop = BinaryOps.MUL
|
||||
|
||||
# TODO: add Neg?
|
||||
# TODO: add Neg? confirm the optimizer on Sub good enough
|
||||
|
||||
# ************* reduce ops *************
|
||||
|
||||
@@ -98,7 +98,7 @@ class Mul(Function):
|
||||
grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
|
||||
return grad_x, grad_y
|
||||
|
||||
# TODO: add Div?
|
||||
# TODO: add Div? is the optimizer on Pow good enough?
|
||||
|
||||
class Pow(Function):
|
||||
def forward(ctx, x, y):
|
||||
@@ -130,15 +130,6 @@ class Expand(Function):
|
||||
in_shape, = ctx.saved_tensors
|
||||
return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
|
||||
|
||||
class Flip(Function):
|
||||
def forward(ctx, x, axis):
|
||||
ctx.save_for_backward(axis)
|
||||
return ctx.movement_op(MovementOps.FLIP, x, axis)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
axis, = ctx.saved_tensors
|
||||
return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
|
||||
|
||||
class Reshape(Function):
|
||||
def forward(ctx, x, shape):
|
||||
ctx.save_for_backward(x.shape)
|
||||
@@ -159,6 +150,8 @@ class Permute(Function):
|
||||
norder = np.argsort(order).tolist()
|
||||
return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder)
|
||||
|
||||
# TODO: merge Slice and Flip into Stride with the 3 arguments
|
||||
|
||||
class Slice(Function):
|
||||
def forward(ctx, x, arg=None):
|
||||
ctx.save_for_backward(x.shape, arg)
|
||||
@@ -169,6 +162,15 @@ class Slice(Function):
|
||||
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
|
||||
return ctx.movement_op(MovementOps.SLICE, grad_output, narg)
|
||||
|
||||
class Flip(Function):
|
||||
def forward(ctx, x, axis):
|
||||
ctx.save_for_backward(axis)
|
||||
return ctx.movement_op(MovementOps.FLIP, x, axis)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
axis, = ctx.saved_tensors
|
||||
return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
|
||||
|
||||
# ************* processing ops *************
|
||||
|
||||
class Conv2D(Function):
|
||||
|
||||
Reference in New Issue
Block a user