mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
cherry binop
This commit is contained in:
@@ -43,7 +43,7 @@ class Sum(Function):
|
||||
input, axis = ctx.saved_tensors
|
||||
if isinstance(axis, int): axis = [axis]
|
||||
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
return grad_output.reshape(shape) + np.zeros_like(input)
|
||||
return cherry_binop(grad_output.reshape(shape), np.zeros_like(input), BinaryOps.ADD)
|
||||
|
||||
"""
|
||||
class Max(Function):
|
||||
|
||||
Reference in New Issue
Block a user