Mirror of https://github.com/tinygrad/tinygrad.git
Commit: simpler sum and max
@@ -56,27 +56,23 @@ class Exp(Function):
 
 class Sum(Function):
   def forward(ctx, input, axis):
-    ctx.save_for_backward(input, axis)
+    ctx.save_for_backward(input.shape)
     return input.sum(axis, keepdims=True)
 
   def backward(ctx, grad_output):
-    input, axis = ctx.saved_tensors
-    shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
-    return grad_output.reshape(shape).expand(input.shape)
+    return grad_output.expand(ctx.saved_tensors[0])
 
 class Max(Function):
   def forward(ctx, inp, axis):
     ret = inp.amax(axis=axis, keepdims=True)
     ctx.save_for_backward(inp, axis, ret)
-    ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
     return ret
 
   def backward(ctx, grad_output):
     input, axis, ret = ctx.saved_tensors
-    shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = (input==ret.reshape(shape))
-    div = ret2.sum(axis=tuple(axis), keepdims=True) if axis is not None else ret2.sum()
-    return ret2*grad_output.reshape(shape)/div.type(input.dtype)
+    ret2 = (input==ret)
+    div = ret2.sum(axis=tuple(axis), keepdims=True)
+    return ret2*grad_output/div.type(input.dtype)
 
 # ************* binary ops *************
 
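Why the new Sum.backward is a bare expand: forward now returns with keepdims=True, so grad_output already carries size-1 dims on the reduced axes and broadcasts straight back to the input shape, and saving only input.shape (instead of the input tensor) means backward no longer keeps the input buffer alive. A minimal NumPy sketch of the same idea (function names here are illustrative, not tinygrad's API):

import numpy as np

def sum_forward(x, axis):
  # keepdims=True leaves size-1 dims on the reduced axes,
  # so the output is already broadcast-compatible with x
  return x.sum(axis=axis, keepdims=True)

def sum_backward(grad_output, input_shape):
  # the gradient of sum is 1 everywhere: just broadcast
  # grad_output back up to the input shape, no reshape needed
  return np.broadcast_to(grad_output, input_shape)

x = np.arange(6.0).reshape(2, 3)
y = sum_forward(x, axis=(1,))                 # shape (2, 1)
dx = sum_backward(np.ones_like(y), x.shape)
assert dx.shape == x.shape and (dx == 1).all()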
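Max.backward gets the same treatment: because ret keeps its reduced dims, (input==ret) broadcasts without any reshape, and dividing by the tie count splits the gradient evenly among equal maxima. A NumPy sketch under the same illustrative naming:

import numpy as np

def max_forward(x, axis):
  return np.max(x, axis=axis, keepdims=True)

def max_backward(x, axis, ret, grad_output):
  # mask of elements equal to the max; broadcasts because ret kept its dims
  mask = (x == ret)
  # number of tied maxima along the reduced axes
  div = mask.sum(axis=axis, keepdims=True)
  # each tied maximum gets an equal share of the incoming gradient
  return mask * grad_output / div.astype(x.dtype)

x = np.array([[1.0, 3.0, 3.0],
              [2.0, 2.0, 0.0]])
ret = max_forward(x, axis=(1,))               # shape (2, 1)
dx = max_backward(x, (1,), ret, np.ones_like(ret))
print(dx)  # [[0. 0.5 0.5], [0.5 0.5 0.]] -- ties split the gradient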