simpler sum and max

This commit is contained in:
George Hotz
2021-11-30 00:53:27 -05:00
parent c39824bc62
commit 5d60df2b10

View File

@@ -56,27 +56,23 @@ class Exp(Function):
class Sum(Function):
  """Sum reduction: forward sums over `axis`, backward expands the gradient back."""

  def forward(ctx, input, axis):
    """Sum `input` over `axis`, keeping the reduced axes as size-1 dimensions.

    Only the input's shape is saved for backward.
    """
    ctx.save_for_backward(input.shape)
    return input.sum(axis, keepdims=True)

  def backward(ctx, grad_output):
    """Expand `grad_output` to the input shape saved by forward."""
    # forward used keepdims=True, so (assuming grad_output has the forward
    # output's shape) the reduced axes are already size 1 and no reshape is
    # needed before expanding.
    # NOTE(review): assumes ctx.saved_tensors[0] is the shape passed to
    # save_for_backward in forward -- confirm against Function.
    return grad_output.expand(ctx.saved_tensors[0])
class Max(Function):
  """Max reduction: forward takes the max over `axis`, backward routes the gradient to the maxima."""

  def forward(ctx, inp, axis):
    """Return the max of `inp` over `axis`, keeping reduced axes as size 1.

    Saves the input, the axis and the keepdims result for backward.
    """
    ret = inp.amax(axis=axis, keepdims=True)
    ctx.save_for_backward(inp, axis, ret)
    return ret

  def backward(ctx, grad_output):
    """Split `grad_output` evenly among the positions that equal the max.

    `axis` must be an iterable of ints here, since it is passed to `tuple()`.
    """
    inp, axis, ret = ctx.saved_tensors
    # ret kept its reduced axes (keepdims=True), so it compares against inp
    # by broadcasting without a reshape.
    is_max = (inp == ret)
    # Number of tied maxima per reduced slice; ties share the gradient equally.
    div = is_max.sum(axis=tuple(axis), keepdims=True)
    return is_max * grad_output / div.type(inp.dtype)
# ************* binary ops *************