cherry binop

This commit is contained in:
George Hotz
2021-06-17 16:50:40 -07:00
parent fcdabea880
commit 9e12c1bbba

View File

@@ -43,7 +43,7 @@ class Sum(Function):
input, axis = ctx.saved_tensors
if isinstance(axis, int): axis = [axis]
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
return grad_output.reshape(shape) + np.zeros_like(input)
return cherry_binop(grad_output.reshape(shape), np.zeros_like(input), BinaryOps.ADD)
"""
class Max(Function):