mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cosmetic
This commit is contained in:
@@ -139,8 +139,7 @@ def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
|
||||
class Sum(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
ctx.save_for_backward(input, axis)
|
||||
ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
|
||||
return ret
|
||||
return reduce_op(ctx, "out += a", "out", input, axis=axis)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, axis = ctx.saved_tensors
|
||||
@@ -201,7 +200,7 @@ class Add(Function):
|
||||
def backward(ctx, grad_output):
|
||||
grad_x, grad_y = grad_output, grad_output
|
||||
shape_x, shape_y = ctx.saved_tensors
|
||||
return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
|
||||
return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y)
|
||||
|
||||
class Sub(Function):
|
||||
def forward(ctx, x, y):
|
||||
@@ -211,7 +210,7 @@ class Sub(Function):
|
||||
def backward(ctx, grad_output):
|
||||
grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output)
|
||||
shape_x, shape_y = ctx.saved_tensors
|
||||
return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y),
|
||||
return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y)
|
||||
|
||||
class Mul(Function):
|
||||
def forward(ctx, x, y):
|
||||
@@ -222,7 +221,7 @@ class Mul(Function):
|
||||
x,y = ctx.saved_tensors
|
||||
grad_x = binary_op(ctx, 'a*b', y, grad_output)
|
||||
grad_y = binary_op(ctx, 'a*b', x, grad_output)
|
||||
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
|
||||
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape)
|
||||
|
||||
class Pow(Function):
|
||||
def forward(ctx, x, y):
|
||||
@@ -235,7 +234,7 @@ class Pow(Function):
|
||||
binary_op(ctx, 'b * (pow((float)a, (float)(b-1.0)))', x, y))
|
||||
grad_y = binary_op(ctx, 'a*b', grad_output,
|
||||
binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y))
|
||||
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape),
|
||||
return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape)
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
|
||||
Reference in New Issue
Block a user