cleanups, remove np

This commit is contained in:
George Hotz
2021-11-30 16:22:00 -05:00
parent e28cdfb0cf
commit e59381d0da

View File

@@ -44,7 +44,7 @@ class Log(Function):
class Exp(Function):
def forward(ctx, input):
ret = np.clip(input, -88, 88).exp()
ret = input.clip(-88, 88).exp()
ctx.save_for_backward(ret)
return ret
@@ -52,7 +52,7 @@ class Exp(Function):
ret, = ctx.saved_tensors
return grad_output * ret
# ************* reduce ops *************
# ************* reduce ops (with keepdims=True) *************
class Sum(Function):
def forward(ctx, input, axis):
@@ -60,7 +60,8 @@ class Sum(Function):
return input.sum(axis, keepdims=True)
def backward(ctx, grad_output):
return grad_output.expand(ctx.saved_tensors[0])
shape_input, = ctx.saved_tensors
return grad_output.expand(shape_input)
class Max(Function):
def forward(ctx, inp, axis):
@@ -74,7 +75,7 @@ class Max(Function):
div = ret2.sum(axis=tuple(axis), keepdims=True)
return ret2*grad_output/div.type(input.dtype)
# ************* binary ops *************
# ************* binary ops (with broadcasting) *************
def unbroadcast(out, in_sh):
# adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]