This commit is contained in:
George Hotz
2021-11-30 01:01:39 -05:00
parent 38dccb3a2e
commit 9b538629bb

View File

@@ -139,8 +139,7 @@ def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
class Sum(Function):
def forward(ctx, input, axis=None):
  """Sum-reduce `input` over `axis` (all elements when axis is None).

  Saves the input tensor and the reduction axis so backward can
  broadcast the incoming gradient back to the input's shape.
  """
  ctx.save_for_backward(input, axis)
  # Diff residue collapsed: the extracted text contained both the old
  # `ret = ...; return ret` pair and the new single return, leaving the
  # second return unreachable. Only one return is needed.
  return reduce_op(ctx, "out += a", "out", input, axis=axis)
def backward(ctx, grad_output):
input, axis = ctx.saved_tensors
@@ -201,7 +200,7 @@ class Add(Function):
def backward(ctx, grad_output):
  """Gradient of Add: d(x+y)/dx = d(x+y)/dy = 1.

  Both operands receive grad_output unchanged, then unbroadcast
  reduces each gradient back to its operand's original shape
  (saved in forward as a (shape_x, shape_y) pair).
  """
  grad_x, grad_y = grad_output, grad_output
  shape_x, shape_y = ctx.saved_tensors
  # Diff residue collapsed: the old return (with trailing comma, which
  # wrapped the pair in an extra 1-tuple ambiguity) and the new return
  # appeared back to back, the second unreachable. Keep the single
  # comma-free form returning a 2-tuple of gradients.
  return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y)
class Sub(Function):
def forward(ctx, x, y):
@@ -211,7 +210,7 @@ class Sub(Function):
def backward(ctx, grad_output):
  """Gradient of Sub: d(x-y)/dx = 1, d(x-y)/dy = -1.

  x's gradient is grad_output as-is; y's gradient is its elementwise
  negation (via the '-a' unary kernel). Each is then unbroadcast to
  its operand's original shape saved during forward.
  """
  grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output)
  shape_x, shape_y = ctx.saved_tensors
  # Diff residue collapsed: old trailing-comma return plus the new
  # return appeared consecutively, the second unreachable. Keep the
  # single return of a plain 2-tuple.
  return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y)
class Mul(Function):
def forward(ctx, x, y):
@@ -222,7 +221,7 @@ class Mul(Function):
x,y = ctx.saved_tensors
grad_x = binary_op(ctx, 'a*b', y, grad_output)
grad_y = binary_op(ctx, 'a*b', x, grad_output)
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape)
class Pow(Function):
def forward(ctx, x, y):
@@ -235,7 +234,7 @@ class Pow(Function):
binary_op(ctx, 'b * (pow((float)a, (float)(b-1.0)))', x, y))
grad_y = binary_op(ctx, 'a*b', grad_output,
binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape)
# ************* movement ops *************