From 9e12c1bbba3b2b6415fe6937edf5431f670cd5ed Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 17 Jun 2021 16:50:40 -0700 Subject: [PATCH] cherry binop --- extra/ops_cherry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extra/ops_cherry.py b/extra/ops_cherry.py index b930e125a4..0a5e957198 100644 --- a/extra/ops_cherry.py +++ b/extra/ops_cherry.py @@ -43,7 +43,7 @@ class Sum(Function): input, axis = ctx.saved_tensors if isinstance(axis, int): axis = [axis] shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))] - return grad_output.reshape(shape) + np.zeros_like(input) + return cherry_binop(grad_output.reshape(shape), np.zeros_like(input), BinaryOps.ADD) """ class Max(Function):